Repository: yinlou/mltk Branch: master Commit: f50c42986fdd Files: 194 Total size: 810.9 KB Directory structure: gitextract_ke00o21_/ ├── .gitignore ├── LICENSE ├── README.md ├── pom.xml └── src/ ├── main/ │ └── java/ │ └── mltk/ │ ├── cmdline/ │ │ ├── Argument.java │ │ ├── CmdLineParser.java │ │ ├── options/ │ │ │ ├── HoldoutValidatedLearnerOptions.java │ │ │ ├── HoldoutValidatedLearnerWithTaskOptions.java │ │ │ ├── LearnerOptions.java │ │ │ ├── LearnerWithTaskOptions.java │ │ │ └── package-info.java │ │ └── package-info.java │ ├── core/ │ │ ├── Attribute.java │ │ ├── BinnedAttribute.java │ │ ├── Bins.java │ │ ├── Copyable.java │ │ ├── DenseVector.java │ │ ├── Instance.java │ │ ├── Instances.java │ │ ├── NominalAttribute.java │ │ ├── NumericalAttribute.java │ │ ├── Sampling.java │ │ ├── SparseVector.java │ │ ├── Vector.java │ │ ├── Writable.java │ │ ├── io/ │ │ │ ├── AttributesReader.java │ │ │ ├── InstancesReader.java │ │ │ ├── InstancesWriter.java │ │ │ └── package-info.java │ │ ├── package-info.java │ │ └── processor/ │ │ ├── Discretizer.java │ │ ├── InstancesSplitter.java │ │ ├── OneHotEncoder.java │ │ └── package-info.java │ ├── feature/ │ │ └── selection/ │ │ ├── BackwardElimination.java │ │ └── package-info.java │ ├── predictor/ │ │ ├── BaggedEnsemble.java │ │ ├── BaggedEnsembleLearner.java │ │ ├── BoostedEnsemble.java │ │ ├── Classifier.java │ │ ├── Ensemble.java │ │ ├── Family.java │ │ ├── HoldoutValidatedLearner.java │ │ ├── Learner.java │ │ ├── LinkFunction.java │ │ ├── Predictor.java │ │ ├── ProbabilisticClassifier.java │ │ ├── Regressor.java │ │ ├── evaluation/ │ │ │ ├── AUC.java │ │ │ ├── ConvergenceTester.java │ │ │ ├── Error.java │ │ │ ├── Evaluator.java │ │ │ ├── LogLoss.java │ │ │ ├── LogisticLoss.java │ │ │ ├── MAE.java │ │ │ ├── Metric.java │ │ │ ├── MetricFactory.java │ │ │ ├── Predictor.java │ │ │ ├── RMSE.java │ │ │ ├── SimpleMetric.java │ │ │ └── package-info.java │ │ ├── function/ │ │ │ ├── Array1D.java │ │ │ ├── Array2D.java │ │ │ ├── 
BaggedLineCutter.java │ │ │ ├── BivariateFunction.java │ │ │ ├── CHistogram.java │ │ │ ├── CompressionUtils.java │ │ │ ├── CubicSpline.java │ │ │ ├── EnsembledLineCutter.java │ │ │ ├── Function1D.java │ │ │ ├── Function2D.java │ │ │ ├── Histogram2D.java │ │ │ ├── LineCutter.java │ │ │ ├── LinearFunction.java │ │ │ ├── SquareCutter.java │ │ │ ├── SubagSequence.java │ │ │ ├── SubaggedLineCutter.java │ │ │ ├── UnivariateFunction.java │ │ │ └── package-info.java │ │ ├── gam/ │ │ │ ├── DenseDesignMatrix.java │ │ │ ├── GA2MLearner.java │ │ │ ├── GAM.java │ │ │ ├── GAMLearner.java │ │ │ ├── GAMUtils.java │ │ │ ├── SPLAMLearner.java │ │ │ ├── ScorecardModelLearner.java │ │ │ ├── SparseDesignMatrix.java │ │ │ ├── interaction/ │ │ │ │ ├── FAST.java │ │ │ │ └── package-info.java │ │ │ ├── package-info.java │ │ │ └── tool/ │ │ │ ├── Diagnostics.java │ │ │ ├── Visualizer.java │ │ │ └── package-info.java │ │ ├── glm/ │ │ │ ├── ElasticNetLearner.java │ │ │ ├── GLM.java │ │ │ ├── GLMLearner.java │ │ │ ├── GLMOptimUtils.java │ │ │ ├── GroupLassoLearner.java │ │ │ ├── LassoLearner.java │ │ │ ├── RidgeLearner.java │ │ │ └── package-info.java │ │ ├── io/ │ │ │ ├── PredictorReader.java │ │ │ ├── PredictorWriter.java │ │ │ └── package-info.java │ │ ├── package-info.java │ │ └── tree/ │ │ ├── DecisionTable.java │ │ ├── DecisionTableLearner.java │ │ ├── RTree.java │ │ ├── RTreeLearner.java │ │ ├── RegressionTree.java │ │ ├── RegressionTreeLeaf.java │ │ ├── RegressionTreeLearner.java │ │ ├── TreeInteriorNode.java │ │ ├── TreeLearner.java │ │ ├── TreeNode.java │ │ ├── ensemble/ │ │ │ ├── BaggedRTrees.java │ │ │ ├── BoostedDTables.java │ │ │ ├── BoostedRTrees.java │ │ │ ├── RTreeList.java │ │ │ ├── TreeEnsembleLearner.java │ │ │ ├── ag/ │ │ │ │ ├── AdditiveGroves.java │ │ │ │ ├── AdditiveGrovesLearner.java │ │ │ │ └── package-info.java │ │ │ ├── brt/ │ │ │ │ ├── BDT.java │ │ │ │ ├── BRT.java │ │ │ │ ├── BRTLearner.java │ │ │ │ ├── BRTUtils.java │ │ │ │ ├── LADBoostLearner.java │ │ │ │ ├── 
LSBoostLearner.java │ │ │ │ ├── LogitBoostLearner.java │ │ │ │ ├── RobustDecisionTableLearner.java │ │ │ │ ├── RobustRegressionTreeLearner.java │ │ │ │ └── package-info.java │ │ │ ├── package-info.java │ │ │ └── rf/ │ │ │ ├── RandomForest.java │ │ │ ├── RandomForestLearner.java │ │ │ ├── RandomRegressionTreeLearner.java │ │ │ └── package-info.java │ │ └── package-info.java │ └── util/ │ ├── ArrayUtils.java │ ├── Element.java │ ├── MathUtils.java │ ├── OptimUtils.java │ ├── Permutation.java │ ├── Queue.java │ ├── Random.java │ ├── Stack.java │ ├── StatUtils.java │ ├── UFSets.java │ ├── VectorUtils.java │ ├── package-info.java │ └── tuple/ │ ├── DoublePair.java │ ├── IntDoublePair.java │ ├── IntDoublePairComparator.java │ ├── IntPair.java │ ├── IntTriple.java │ ├── LongDoublePair.java │ ├── LongDoublePairComparator.java │ ├── LongPair.java │ ├── Pair.java │ ├── Triple.java │ └── package-info.java └── test/ └── java/ └── mltk/ ├── core/ │ ├── BinsTest.java │ ├── InstancesTestHelper.java │ ├── io/ │ │ ├── AttributesReaderTest.java │ │ └── InstancesReaderTest.java │ └── processor/ │ ├── DiscretizerTest.java │ └── InstancesSplitterTest.java ├── predictor/ │ ├── evaluation/ │ │ ├── AUCTest.java │ │ ├── ConvergenceTesterTest.java │ │ ├── ErrorTest.java │ │ ├── LogLossTest.java │ │ ├── LogisticLossTest.java │ │ ├── MAETest.java │ │ ├── MetricFactoryTest.java │ │ └── RMSETest.java │ ├── glm/ │ │ └── GLMTest.java │ └── tree/ │ ├── DecisionTableLearnerTest.java │ ├── DecisionTableTest.java │ ├── DecisionTableTestHelper.java │ ├── RegressionTreeLearnerTest.java │ ├── RegressionTreeTest.java │ ├── RegressionTreeTestHelper.java │ └── ensemble/ │ ├── BoostedDTablesTest.java │ ├── BoostedRTreesTest.java │ └── brt/ │ ├── BDTTest.java │ ├── BRTTest.java │ ├── BRTUtilsTest.java │ └── LogitBoostLearnerTest.java └── util/ ├── ArrayUtilsTest.java ├── MathUtilsTest.java ├── OptimUtilsTest.java ├── StatUtilsTest.java └── VectorUtilsTest.java ================================================ 
FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *~ *.swp *.class # osx .DS_Store # Eclipse build artifacts .project .classpath # Maven target build/ **/build/ classes # IntelliJ .idea/ *.iml *.ipr *.iws **/*.iml out/ **/.classpath **/.project .settings/ **/.settings/ ================================================ FILE: LICENSE ================================================ Copyright (c) 2012-2019, Yin Lou All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the organization nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
================================================ FILE: README.md ================================================ # Machine Learning Tool Kit MLTK is a collection of various supervised machine learning algorithms, which is designed for directly training models and further development. For questions or suggestions with the code, please email machinelearningtoolkit@gmail.com. See [wiki](https://github.com/yinlou/mltk/wiki) for full documentation, examples and other information. ================================================ FILE: pom.xml ================================================ 4.0.0 com.github.yinlou mltk 0.1.1-SNAPSHOT mltk Machine Learning Tool Kit https://github.com/yinlou/mltk BSD 3-Clause https://opensource.org/licenses/BSD-3-Clause Yin Lou machinelearningtoolkit@gmail.com scm:git:git://github.com/yinlou/mltk.git scm:git:ssh://github.com:yinlou/mltk.git https://github.com/yinlou/mltk.git UTF-8 UTF-8 1.8 1.8 doclint-java8-disable [1.8,) -Xdoclint:none ossrh https://oss.sonatype.org/content/repositories/snapshots org.sonatype.plugins nexus-staging-maven-plugin 1.6.7 true ossrh https://oss.sonatype.org/ true org.apache.maven.plugins maven-compiler-plugin 3.5.1 true ${maven.compiler.source} ${maven.compiler.target} org.apache.maven.plugins maven-source-plugin 2.2.1 attach-sources jar-no-fork org.apache.maven.plugins maven-javadoc-plugin 2.10.3 8 Machine Learning Tool Kit ${project.version} API Machine Learning Tool Kit ${project.version} API ${javadoc.opts} attach-javadocs package jar org.apache.maven.plugins maven-gpg-plugin 1.5 sign-artifacts verify sign junit junit 4.13.1 ================================================ FILE: src/main/java/mltk/cmdline/Argument.java ================================================ package mltk.cmdline; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.Target; /** * Command line argument. 
*
 * @author Yin Lou
 *
 */
@Retention(java.lang.annotation.RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface Argument {

	/**
	 * Name of this argument.
	 *
	 * @return the name of this argument.
	 */
	String name() default "";

	/**
	 * Description of this argument.
	 *
	 * @return the description of this argument.
	 */
	String description() default "";

	/**
	 * Whether this argument is required.
	 *
	 * @return {@code true} if the argument is required.
	 */
	boolean required() default false;

}

================================================
FILE: src/main/java/mltk/cmdline/CmdLineParser.java
================================================
package mltk.cmdline;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Class for command line parser. The parser collects every field of an object that carries an
 * {@link Argument} annotation and fills those fields from {@code name value} pairs given on the
 * command line.
 *
 * @author Yin Lou
 *
 */
public class CmdLineParser {

	private final String name;
	private final Object obj;
	// argList.get(i) is the annotation found on fieldList.get(i); the two lists grow together.
	private final List<Argument> argList;
	private final List<Field> fieldList;

	/**
	 * Constructor.
	 *
	 * @param clazz the class object; its canonical name is shown in the usage message.
	 * @param obj the object whose annotated fields receive the parsed values.
	 */
	public CmdLineParser(Class<?> clazz, Object obj) {
		this.name = clazz.getCanonicalName();
		this.obj = obj;
		argList = new ArrayList<>();
		fieldList = new ArrayList<>();
		// getFields() yields public fields (including inherited ones), getDeclaredFields() yields
		// every field declared directly on the class. Public fields declared on the class show up
		// in both arrays; processFields skips the ones that were already registered.
		processFields(obj.getClass().getFields());
		processFields(obj.getClass().getDeclaredFields());
	}

	/**
	 * Parses the command line arguments. The arguments are read as consecutive
	 * {@code name value} pairs; when a name occurs more than once the last value wins. Fields
	 * whose type is not {@code String}, {@code int}, {@code double}, {@code float},
	 * {@code boolean}, {@code long} or {@code char} are left untouched.
	 *
	 * @param args the command line arguments.
	 * @throws IllegalArgumentException if the number of arguments is odd, a required argument is
	 *         missing, or a value cannot be converted to the type of its field.
	 * @throws IllegalAccessException if a field cannot be written.
	 */
	public void parse(String[] args) throws IllegalArgumentException, IllegalAccessException {
		if (args.length % 2 != 0) {
			throw new IllegalArgumentException(
					"Expected name/value pairs but got " + args.length + " argument(s)");
		}

		Map<String, String> map = new HashMap<>();
		for (int i = 0; i < args.length; i += 2) {
			map.put(args[i], args[i + 1]);
		}

		for (int i = 0; i < argList.size(); i++) {
			Field field = fieldList.get(i);
			Argument arg = argList.get(i);
			String value = map.get(arg.name());
			if (value == null) {
				if (arg.required()) {
					throw new IllegalArgumentException("Missing required argument: " + arg.name());
				}
				continue;
			}

			Class<?> fclass = field.getType();
			field.setAccessible(true);
			if (fclass == String.class) {
				field.set(obj, value);
			} else if (fclass == int.class) {
				field.setInt(obj, Integer.parseInt(value));
			} else if (fclass == double.class) {
				field.setDouble(obj, Double.parseDouble(value));
			} else if (fclass == float.class) {
				field.setFloat(obj, Float.parseFloat(value));
			} else if (fclass == boolean.class) {
				field.setBoolean(obj, Boolean.parseBoolean(value));
			} else if (fclass == long.class) {
				field.setLong(obj, Long.parseLong(value));
			} else if (fclass == char.class) {
				// charAt(0) on an empty string would surface as StringIndexOutOfBoundsException.
				if (value.isEmpty()) {
					throw new IllegalArgumentException(
							"Empty value for character argument: " + arg.name());
				}
				field.setChar(obj, value.charAt(0));
			}
		}
	}

	/**
	 * Prints the generated usage to standard error. Required arguments are listed first, followed
	 * by optional arguments enclosed in square brackets.
	 */
	public void printUsage() {
		StringBuilder sb = new StringBuilder();
		sb.append("Usage: ").append(name).append("\n");

		StringBuilder required = new StringBuilder();
		StringBuilder optional = new StringBuilder();
		for (Argument arg : argList) {
			if (arg.required()) {
				required.append(arg.name()).append("\t").append(arg.description()).append("\n");
			} else {
				optional.append("[").append(arg.name()).append("]\t").append(arg.description()).append("\n");
			}
		}
		sb.append(required).append(optional);

		System.err.println(sb.toString());
	}

	/**
	 * Registers every field that carries an {@link Argument} annotation and has not been
	 * registered before.
	 *
	 * @param fields the candidate fields.
	 */
	private void processFields(Field[] fields) {
		for (Field field : fields) {
			Argument argument = field.getAnnotation(Argument.class);
			if (argument != null && !fieldList.contains(field)) {
				fieldList.add(field);
				argList.add(argument);
			}
		}
	}

}

================================================
FILE: src/main/java/mltk/cmdline/options/HoldoutValidatedLearnerOptions.java
================================================
package mltk.cmdline.options;

import mltk.cmdline.Argument;

/**
 * Learner options extended with a validation set path, an evaluation metric and a convergence
 * criteria.
 */
public class HoldoutValidatedLearnerOptions extends LearnerOptions {

	@Argument(name = "-v", description = "valid set path")
	public String validPath = null;

	@Argument(name = "-e", description = "evaluation metric (default: default metric of task)")
	public String metric = null;

	@Argument(name = "-S", description = "convergence criteria (default: -1) ")
	public String cc = "-1";

}

================================================
FILE: src/main/java/mltk/cmdline/options/HoldoutValidatedLearnerWithTaskOptions.java
================================================
package mltk.cmdline.options;

import mltk.cmdline.Argument;

/**
 * Holdout-validated learner options extended with the task selection.
 */
public class HoldoutValidatedLearnerWithTaskOptions extends HoldoutValidatedLearnerOptions {

	@Argument(name = "-g", description = "task between classification (c) and regression (r) (default: r)")
	public String task = "r";

}

================================================
FILE: src/main/java/mltk/cmdline/options/LearnerOptions.java
================================================
package
mltk.cmdline.options; import mltk.cmdline.Argument; public class LearnerOptions { @Argument(name = "-r", description = "attribute file path") public String attPath = null; @Argument(name = "-t", description = "train set path", required = true) public String trainPath = null; @Argument(name = "-o", description = "output model path") public String outputModelPath = null; @Argument(name = "-V", description = "verbose (default: true)") public boolean verbose = true; } ================================================ FILE: src/main/java/mltk/cmdline/options/LearnerWithTaskOptions.java ================================================ package mltk.cmdline.options; import mltk.cmdline.Argument; public class LearnerWithTaskOptions extends LearnerOptions { @Argument(name = "-g", description = "task between classification (c) and regression (r) (default: r)") public String task = "r"; } ================================================ FILE: src/main/java/mltk/cmdline/options/package-info.java ================================================ /** * Provides classes for command line options. * */ package mltk.cmdline.options; ================================================ FILE: src/main/java/mltk/cmdline/package-info.java ================================================ /** * Provides classes for command line parser. */ package mltk.cmdline; ================================================ FILE: src/main/java/mltk/core/Attribute.java ================================================ package mltk.core; /** * Abstract class for attributes. * * @author Yin Lou * */ public abstract class Attribute implements Comparable, Copyable { public enum Type { /** * Nominal type. */ NOMINAL, /** * Numeric type. */ NUMERIC, /** * Binned type. */ BINNED; } protected Type type; protected int index; protected String name; /** * Returns the type of this attribute. * * @return the type of this attribute. */ public final Type getType() { return type; } /** * Returns the index of this attribute. 
* * @return the index of this attribute. */ public final int getIndex() { return index; } /** * Sets the index of this attribute. * * @param index the new index. */ public final void setIndex(int index) { this.index = index; } /** * Returns the name of this attribute. * * @return the name of this attribute. */ public final String getName() { return name; } @Override public int compareTo(Attribute att) { return (this.index - att.index); } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + index; result = prime * result + ((name == null) ? 0 : name.hashCode()); result = prime * result + ((type == null) ? 0 : type.hashCode()); return result; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; Attribute other = (Attribute) obj; if (index != other.index) return false; if (name == null) { if (other.name != null) return false; } else if (!name.equals(other.name)) return false; if (type != other.type) return false; return true; } @Override public String toString() { return name; } } ================================================ FILE: src/main/java/mltk/core/BinnedAttribute.java ================================================ package mltk.core; import java.util.Arrays; import mltk.util.ArrayUtils; /** * Class for discretized attributes. * * @author Yin Lou * */ public class BinnedAttribute extends Attribute { protected int numBins; protected Bins bins; /** * Constructor. * * @param name the name of this attribute. * @param numBins number of bins for this attribute. */ public BinnedAttribute(String name, int numBins) { this(name, numBins, -1); } /** * Constructor. * * @param name the name of this attribute. * @param numBins number of bins for this attribute. * @param index the index of this attribute. 
*/ public BinnedAttribute(String name, int numBins, int index) { this.name = name; this.numBins = numBins; this.bins = null; this.index = index; this.type = Type.BINNED; } /** * Constructor. * * @param name the name of this attribute. * @param bins bins for this attribute. */ public BinnedAttribute(String name, Bins bins) { this(name, bins, -1); } /** * Constructor. * * @param name the name of this attribute. * @param bins bins for this attribute. * @param index the index of this attribute. */ public BinnedAttribute(String name, Bins bins, int index) { this(name, bins.size()); this.bins = bins; this.index = index; } @Override public BinnedAttribute copy() { BinnedAttribute copy = (bins == null ? new BinnedAttribute(name, numBins) : new BinnedAttribute(name, bins)); copy.index = index; return copy; } /** * Returns the number of bins. * * @return the number of bins. */ public int getNumBins() { return numBins; } /** * Returns the bins. * * @return the bins. */ public Bins getBins() { return bins; } public String toString() { if (bins == null) { return name + ": binned (" + numBins + ")"; } else { return name + ": binned (" + bins.size() + ";" + Arrays.toString(bins.boundaries) + ";" + Arrays.toString(bins.medians) + ")"; } } /** * Parses a binned attribute object from a string. * * @param str the string. * @return a parsed binned attribute. 
*/ public static BinnedAttribute parse(String str) { String[] data = str.split(": "); int start = data[1].indexOf('(') + 1; int end = data[1].indexOf(')'); String[] strs = data[1].substring(start, end).split(";"); int numBins = Integer.parseInt(strs[0]); if (strs.length == 1) { return new BinnedAttribute(data[0], numBins); } else { double[] boundaries = ArrayUtils.parseDoubleArray(strs[1]); double[] medians = ArrayUtils.parseDoubleArray(strs[2]); Bins bins = new Bins(boundaries, medians); return new BinnedAttribute(data[0], bins); } } } ================================================ FILE: src/main/java/mltk/core/Bins.java ================================================ package mltk.core; import mltk.util.ArrayUtils; /** * Class for bins. Each bin is defined as its upper bound and median. * * @author Yin Lou * */ public class Bins { /** * The upper bounds for each bin */ protected double[] boundaries; /** * The medians for each bin */ protected double[] medians; protected Bins() { } /** * Constructor. * * @param boundaries the uppber bounds for each bin. * @param medians the medians for each bin. */ public Bins(double[] boundaries, double[] medians) { if (boundaries.length != medians.length) { throw new IllegalArgumentException("Boundary size doesn't match medians size"); } this.boundaries = boundaries; this.medians = medians; } /** * Returns the number of bins. * * @return the number of bins. */ public int size() { return boundaries.length; } /** * Returns the bin index given a real value using binary search. * * @param value the real value to discretize. * @return the discretized index. */ public int getIndex(double value) { if (value < boundaries[0]) { return 0; } else if (value >= boundaries[boundaries.length - 1]) { return boundaries.length - 1; } else { return ArrayUtils.findInsertionPoint(boundaries, value); } } /** * Returns the median of a bin. * * @param index the index of the bin. * @return the median of the bin. 
*/ public double getValue(int index) { return medians[index]; } /** * Returns the upper bounds for each bin. * * @return the upper bounds for each bin. */ public double[] getBoundaries() { return boundaries; } /** * Returns the medians for each bin. * * @return the medians for each bin. */ public double[] getMedians() { return medians; } } ================================================ FILE: src/main/java/mltk/core/Copyable.java ================================================ package mltk.core; /** * Copyable interface. * * @author Yin Lou * * @param the type of the object. */ public interface Copyable { /** * Returns a (deep) copy of the object. * * @return a (deep) copy of the object. */ public T copy(); } ================================================ FILE: src/main/java/mltk/core/DenseVector.java ================================================ package mltk.core; import java.util.Arrays; /** * Class for dense vectors. * * @author Yin Lou * */ public class DenseVector implements Vector { protected double[] values; /** * Constructs a dense vector from a double array. * * @param values the double array. */ public DenseVector(double[] values) { this.values = values; } @Override public double getValue(int index) { return values[index]; } @Override public double[] getValues() { return values; } @Override public double[] getValues(int... 
indices) { double[] values = new double[indices.length]; for (int i = 0; i < values.length; i++) { values[i] = getValue(indices[i]); } return values; } @Override public void setValue(int index, double value) { values[index] = value; } @Override public void setValue(int[] indices, double[] v) { for (int i = 0; i < indices.length; i++) { values[indices[i]] = v[i]; } } @Override public DenseVector copy() { double[] copyValues = Arrays.copyOf(values, values.length); return new DenseVector(copyValues); } @Override public boolean isSparse() { return false; } } ================================================ FILE: src/main/java/mltk/core/Instance.java ================================================ package mltk.core; import mltk.util.MathUtils; /** * Class for instances. * * @author Yin Lou * */ public class Instance implements Copyable { protected Vector vector; protected double[] target; protected double weight; /** * Constructs a dense instance from values, target and weight. * * @param values the values. * @param target the target. * @param weight the weight. */ public Instance(double[] values, double target, double weight) { this.vector = new DenseVector(values); this.target = new double[] { target }; this.weight = weight; } /** * Constructs a sparse instance from indices, values, target and weight. * * @param indices the indices. * @param values the values. * @param target the target. * @param weight the weight. */ public Instance(int[] indices, double[] values, double target, double weight) { this.vector = new SparseVector(indices, values); this.target = new double[] { target }; this.weight = weight; } /** * Constructs a dense instance from vector, target and weight. * * @param vector the vector. * @param target the target. * @param weight the weight. */ public Instance(Vector vector, double target, double weight) { this.vector = vector; this.target = new double[] { target }; this.weight = weight; } /** * Constructor with default weight 1.0. 
* * @param values the values. * @param target the target. */ public Instance(double[] values, double target) { this(values, target, 1.0); } /** * Constructor with default weight 1.0. * * @param indices the indices. * @param values the values. * @param target the target. */ public Instance(int[] indices, double[] values, double target) { this(indices, values, target, 1.0); } /** * Construct with default weight 1.0. * * @param vector the vector. * @param target the target. */ public Instance(Vector vector, double target) { this(vector, target, 1.0); } /** * Constructor with default weight 1.0 and no target. * * @param values the values. */ public Instance(double[] values) { this(values, Double.NaN); } /** * Constructor with default weight 1.0 and no target. * * @param indices the indices. * @param values the values. */ public Instance(int[] indices, double[] values) { this(indices, values, Double.NaN); } /** * Constructor with default weight 1.0 and no target. * * @param vector the vector. * @param values the values. */ public Instance(Vector vector, double[] values) { this(vector, Double.NaN); } /** * Copy constructor. * * @param instance the other instance to copy. */ public Instance(Instance instance) { this.vector = instance.vector; this.weight = instance.weight; this.target = instance.target; } /** * Returns {@code true} if the instance is sparse. * * @return {@code true} if the instance is sparse. */ public boolean isSparse() { return vector.isSparse(); } /** * Returns the value at specified attribute. * * @param attIndex the attribute index. * @return the value at specified attribute. */ public final double getValue(int attIndex) { return vector.getValue(attIndex); } /** * Returns the values. * * @return the values. */ public final double[] getValues() { return vector.getValues(); } /** * Returns an array representation of values at specified attributes. * * @param attributes the attributes. * @return an array representation of values at specified attributes. 
*/ public final double[] getValues(int... attributes) { return vector.getValues(attributes); } /** * Sets the value at specified attribute. * * @param attIndex the attribute index. * @param value the new value to set. */ public final void setValue(int attIndex, double value) { vector.setValue(attIndex, value); } /** * Sets the value at specified attribute. * * @param attribute the attribute. * @param value the new value to set. */ public final void setValue(Attribute attribute, double value) { setValue(attribute.getIndex(), value); } /** * Sets the values at specified attributes. * * @param attributes the attribute index array. * @param v the value array. */ public final void setValue(int[] attributes, double[] v) { for (int i = 0; i < attributes.length; i++) { setValue(attributes[i], v[i]); } } @Override public Instance copy() { Vector copyVector = vector.copy(); return new Instance(copyVector, target[0], weight); } /** * Returns a shallow copy. * * @return a shallow copy. */ public Instance clone() { return new Instance(this); } /** * Returns {@code true} if a specific attribute value is missing. * * @param attIndex the attribute index. * @return {@code true} if a specific attribute value is missing. */ public boolean isMissing(int attIndex) { return Double.isNaN(getValue(attIndex)); } /** * Returns the value at specified attribute. * * @param att the attribute object. * @return the value at specified attribute. */ public double getValue(Attribute att) { return getValue(att.getIndex()); } /** * Returns the vector. * * @return the vector. */ public Vector getVector() { return vector; } /** * Returns the weight of this instance. * * @return the weight of this instance. */ public double getWeight() { return weight; } /** * Sets the weight of this instance. * * @param weight the new weight of this instance. */ public void setWeight(double weight) { this.weight = weight; } /** * Returns the target value. * * @return the target value. 
*/ public double getTarget() { return target[0]; } /** * Sets the target value. * * @param target the new target value. */ public void setTarget(double target) { this.target[0] = target; } /** * Returns the string representation of this instance. */ public String toString() { StringBuilder sb = new StringBuilder(); if (isSparse()) { sb.append(getTarget()); SparseVector sv = (SparseVector) vector; int[] indices = sv.getIndices(); double[] values = sv.getValues(); for (int i = 0; i < indices.length; i++) { sb.append(" ").append(indices[i]).append(":"); print(sb, values[i]); } } else { double[] values = getValues(); print(sb, values[0]); for (int i = 1; i < values.length; i++) { sb.append("\t"); print(sb, values[i]); } if (!Double.isNaN(getTarget())) { sb.append("\t"); print(sb, getTarget()); } } return sb.toString(); } protected void print(StringBuilder sb, double v) { if (MathUtils.isInteger(v)) { sb.append((int) v); } else { sb.append(v); } } } ================================================ FILE: src/main/java/mltk/core/Instances.java ================================================ package mltk.core; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; import mltk.util.Random; /** * Class for handling an ordered set of instances. * * @author Yin Lou * */ public class Instances implements Iterable, Copyable { protected List attributes; protected List instances; protected Attribute targetAtt; /** * Constructs a dataset from attributes. * * @param attributes the attributes. */ public Instances(List attributes) { this(attributes, null); } /** * Constructs a dataset from attributes, with specified capacity. * * @param attributes the attributes. * @param capacity the capacity. */ public Instances(List attributes, int capacity) { this(attributes, null, capacity); } /** * Constructs a dataset from attributes and target attribute. * * @param attributes the attributes. * @param targetAtt the target attribute. 
*/ public Instances(List attributes, Attribute targetAtt) { this(attributes, targetAtt, 1000); } /** * Constructs a dataset from attributes and target attribute, with specified capacity. * * @param attributes the attributes. * @param targetAtt the target attribute. * @param capacity the capacity. */ public Instances(List attributes, Attribute targetAtt, int capacity) { this.attributes = attributes; this.targetAtt = targetAtt; this.instances = new ArrayList<>(capacity); } /** * Copy constructor. * * @param instances the instances to copy. */ public Instances(Instances instances) { this.attributes = instances.attributes; this.targetAtt = instances.targetAtt; this.instances = new ArrayList<>(instances.instances); } /** * Adds an instance to the end of the dataset. * * @param instance the instance to add. */ public void add(Instance instance) { instances.add(instance); } /** * Returns the instance at given index. * * @param index the index. * @return the instance at given index. */ public Instance get(int index) { return instances.get(index); } /** * Returns the target attribute. * * @return the target attribute. */ public final Attribute getTargetAttribute() { return targetAtt; } /** * Sets the target attribute. * * @param targetAtt the target attribute. */ public final void setTargetAttribute(Attribute targetAtt) { this.targetAtt = targetAtt; } @Override public Iterator iterator() { return instances.iterator(); } /** * Returns the size of this dataset, i.e., the number of instances. * * @return the size of this dataset. */ public final int size() { return instances.size(); } /** * Returns the dimension of this dataset, i.e., the number of attributes. Note that class attribute does not count. * * @return the dimension of this dataset. */ public final int dimension() { return attributes.size(); } /** * Returns the list of attributes. * * @return the list of attributes. 
*/ public List getAttributes() { return attributes; } /** * Returns the list of attributes at given locations. * * @param indices the indices. * @return the list of attributes at given locations. */ public List getAttributes(int... indices) { List attributes = new ArrayList<>(indices.length); for (int index : indices) { attributes.add(this.attributes.get(index)); } return attributes; } /** * Sets the attributes. * * @param attributes the attributes to set. */ public void setAttributes(List attributes) { this.attributes = attributes; } /** * Resets this dataset. */ public void clear() { instances.clear(); } /** * Randomly permutes this dataset. */ public void shuffle() { Collections.shuffle(instances, Random.getInstance().getRandom()); } /** * Randomly permutes this dataset. * * @param rand the source of randomness to use to shuffle the dataset. */ public void shuffle(java.util.Random rand) { Collections.shuffle(instances, rand); } @Override public Instances copy() { List attributes = new ArrayList<>(this.attributes); Instances copy = new Instances(attributes, targetAtt, instances.size()); for (Instance instance : instances) { copy.add(instance.copy()); } return copy; } } ================================================ FILE: src/main/java/mltk/core/NominalAttribute.java ================================================ package mltk.core; /** * Class for nominal attributes. * * @author Yin Lou * */ public class NominalAttribute extends Attribute { protected String[] states; /** * Constructor. * * @param name the name of this attribute. * @param states the states for this attribute. */ public NominalAttribute(String name, String[] states) { this(name, states, -1); } /** * Constructor. * * @param name the name of this attribute. * @param states the states for this attribute. * @param index the index of this attribute. 
*/ public NominalAttribute(String name, String[] states, int index) { this.name = name; this.states = states; this.index = index; this.type = Type.NOMINAL; } public NominalAttribute copy() { NominalAttribute copy = new NominalAttribute(name, states); copy.index = this.index; return copy; } /** * Returns the cardinality of this attribute. * * @return the cardinality of this attribute. */ public int getCardinality() { return states.length; } /** * Returns the state given an index. * * @param index the index. * @return the state given an index. */ public String getState(int index) { return states[index]; } /** * Returns the states. * * @return the states. */ public String[] getStates() { return states; } public String toString() { StringBuilder sb = new StringBuilder(); sb.append(name).append(": {").append(states[0]); for (int i = 1; i < states.length; i++) { sb.append(", ").append(states[i]); } sb.append("}"); return sb.toString(); } /** * Parses a nominal attribute ojbect from a string. * * @param str the string. * @return a parsed nominal attribute. */ public static NominalAttribute parse(String str) { String[] data = str.split(": "); int start = data[1].indexOf('{') + 1; int end = data[1].indexOf('}'); String[] states = data[1].substring(start, end).split(","); for (int j = 0; j < states.length; j++) { states[j] = states[j].trim(); } return new NominalAttribute(data[0], states); } } ================================================ FILE: src/main/java/mltk/core/NumericalAttribute.java ================================================ package mltk.core; /** * Class for numerical attributes. * * @author Yin Lou * */ public class NumericalAttribute extends Attribute { /** * Constructor. * * @param name the name of this attribute. */ public NumericalAttribute(String name) { this(name, -1); } /** * Constructor. * * @param name the name of this attribute. * @param index the index of this attribute. 
*/ public NumericalAttribute(String name, int index) { this.name = name; this.index = index; this.type = Type.NUMERIC; } public NumericalAttribute copy() { NumericalAttribute copy = new NumericalAttribute(this.name); copy.index = this.index; return copy; } public String toString() { return name + ": cont"; } /** * Parses a numerical attribute object from a string. * * @param str the string. * @return a parsed numerical attribute. */ public static NumericalAttribute parse(String str) { String[] data = str.split(": "); return new NumericalAttribute(data[0]); } } ================================================ FILE: src/main/java/mltk/core/Sampling.java ================================================ package mltk.core; import java.util.HashMap; import java.util.List; import java.util.Map; import mltk.util.Permutation; import mltk.util.Random; import mltk.util.tuple.IntPair; /** * Class for creating samples. * * @author Yin Lou * */ public class Sampling { /** * Returns a bootstrap sample. * * @param instances the data set. * @return a bootstrap sample. */ public static Instances createBootstrapSample(Instances instances) { Random rand = Random.getInstance(); Map map = new HashMap<>(); for (int i = 0; i < instances.size(); i++) { int idx = rand.nextInt(instances.size()); map.put(idx, map.getOrDefault(idx, 0) + 1); } Instances bag = new Instances(instances.getAttributes(), instances.getTargetAttribute(), map.size()); for (Integer idx : map.keySet()) { int weight = map.get(idx); Instance instance = instances.get(idx).clone(); instance.setWeight(weight); bag.add(instance); } return bag; } /** * Returns a bootstrap sample with out-of-bag samples. * * @param instances the data set. * @param bagIndices the index of sampled instances with weights. * @param oobIndices the out-of-bag indexes. 
*/ public static void createBootstrapSample(Instances instances, Map bagIndices, List oobIndices) { Random rand = Random.getInstance(); for (;;) { bagIndices.clear(); oobIndices.clear(); for (int i = 0; i < instances.size(); i++) { int idx = rand.nextInt(instances.size()); bagIndices.put(idx, bagIndices.getOrDefault(idx, 0) + 1); } for (int i = 0; i < instances.size(); i++) { if (!bagIndices.containsKey(i)) { oobIndices.add(i); } } if (oobIndices.size() > 0) { break; } } } /** * Returns a bootstrap sample of indices and weights. * * @param n the size of the dataset to sample. * @return a bootstrap sample of indices and weights. */ public static IntPair[] createBootstrapSampleIndices(int n) { Random rand = Random.getInstance(); Map map = new HashMap<>(); for (int i = 0; i < n; i++) { int idx = rand.nextInt(n); map.put(idx, map.getOrDefault(idx, 0) + 1); } IntPair[] indices = new IntPair[map.size()]; int k = 0; for (Map.Entry entry : map.entrySet()) { indices[k++] = new IntPair(entry.getKey(), entry.getValue()); } return indices; } /** * Returns a subsample. * * @param instances the dataset. * @param n the sample size. * @return a subsample. */ public static Instances createSubsample(Instances instances, int n) { Permutation perm = new Permutation(instances.size()); perm.permute(); int[] a = perm.getPermutation(); Instances sample = new Instances(instances.getAttributes(), instances.getTargetAttribute(), n); for (int i = 0; i < n; i++) { sample.add(instances.get(a[i])); } return sample; } /** * Returns a set of bags. * * @param instances the dataset. * @param b the number of bagging iterations. * @return a set of bags. 
*/ public static Instances[] createBags(Instances instances, int b) { Instances[] bags = null; if (b <= 0) { // No bagging bags = new Instances[] { instances }; } else { bags = new Instances[b]; for (int i = 0; i < b; i++) { bags[i] = Sampling.createBootstrapSample(instances); } } return bags; } } ================================================ FILE: src/main/java/mltk/core/SparseVector.java ================================================ package mltk.core; import java.util.Arrays; /** * Class for sparse vectors. * * @author Yin Lou * */ public class SparseVector implements Vector { protected int[] indices; protected double[] values; /** * Constructs a sparse vector from sparse-format arrays. * * @param indices the indices array. * @param values the values array. */ public SparseVector(int[] indices, double[] values) { this.indices = indices; this.values = values; } @Override public SparseVector copy() { int[] copyIndices = Arrays.copyOf(indices, indices.length); double[] copyValues = Arrays.copyOf(values, values.length); return new SparseVector(copyIndices, copyValues); } @Override public double getValue(int index) { int idx = Arrays.binarySearch(indices, index); if (idx >= 0) { return values[idx]; } else { return 0; } } @Override public double[] getValues() { return values; } /** * Returns the internal representation of indices. * * @return the internal representation of indices. */ public int[] getIndices() { return indices; } @Override public double[] getValues(int... 
indices) { double[] values = new double[indices.length]; for (int i = 0; i < values.length; i++) { values[i] = getValue(indices[i]); } return values; } @Override public void setValue(int index, double value) { int idx = Arrays.binarySearch(indices, index); if (idx >= 0) { values[idx] = value; } else { throw new UnsupportedOperationException(); } } @Override public void setValue(int[] indices, double[] v) { for (int i = 0; i < indices.length; i++) { setValue(indices[i], v[i]); } } @Override public boolean isSparse() { return true; } } ================================================ FILE: src/main/java/mltk/core/Vector.java ================================================ package mltk.core; /** * Interface for vectors. * * @author Yin Lou * */ public interface Vector extends Copyable { /** * Returns the value at specified index. * * @param index the index. * @return the value at specified index. */ public double getValue(int index); /** * Returns the internal representation of values. * * @return the internal representation of values. */ public double[] getValues(); /** * Returns an array representation of values at specified indices. * * @param indices the indices. * @return an array representation of values at specified indices. */ public double[] getValues(int... indices); /** * Sets the value at specified index. * * @param index the index. * @param value the new value to set. */ public void setValue(int index, double value); /** * Sets the values at specified indices. * * @param indices the index array. * @param v the value array. */ public void setValue(int[] indices, double[] v); /** * Returns {@code true} if the vector is sparse. * * @return {@code true} if the vector is sparse. */ public boolean isSparse(); /** * Returns a (deep) copy of the vector. * * @return a (deep) copy of the vector. 
*/ public Vector copy(); } ================================================ FILE: src/main/java/mltk/core/Writable.java ================================================ package mltk.core; import java.io.BufferedReader; import java.io.PrintWriter; /** * Writable interface. * * @author Yin Lou * */ public interface Writable { /** * Reads in this object. * * @param in the reader. * @throws Exception */ void read(BufferedReader in) throws Exception; /** * Writes this object. * * @param out the writer. * @throws Exception */ void write(PrintWriter out) throws Exception; } ================================================ FILE: src/main/java/mltk/core/io/AttributesReader.java ================================================ package mltk.core.io; import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; import mltk.core.Attribute; import mltk.core.BinnedAttribute; import mltk.core.NominalAttribute; import mltk.core.NumericalAttribute; import mltk.util.tuple.Pair; /** * Class for reading attributes. It only reads in a list of attributes from the attribute file. * * @author Yin Lou * */ public class AttributesReader { /** * Reads attributes and class attribute from attribute file. * * @param attFile the attribute file. * @return a pair of attributes and target attribute (null if no target attribute). * @throws IOException */ public static Pair, Attribute> read(String attFile) throws IOException { BufferedReader br = new BufferedReader(new FileReader(attFile), 65535); Pair, Attribute> pair = read(br); br.close(); return pair; } /** * Reads attributes and class attribute from attribute file. * * @param br the reader. * @return a pair of attributes and target attribute (null if no target attribute). 
* @throws IOException */ public static Pair, Attribute> read(BufferedReader br) throws IOException { List attributes = new ArrayList(); Attribute targetAtt = null; Set usedNames = new HashSet<>(); for (int i = 0;; i++) { String line = br.readLine(); if (line == null) { break; } Attribute att = null; if (line.indexOf("binned") != -1) { att = BinnedAttribute.parse(line); } else if (line.indexOf("{") != -1) { att = NominalAttribute.parse(line); } else { att = NumericalAttribute.parse(line); } att.setIndex(i); if (line.indexOf(" (target)") != -1) { targetAtt = att; i--; } else { if (usedNames.contains(att.getName())) { throw new RuntimeException("Duplicate attribute name: " + att.getName()); } usedNames.add(att.getName()); attributes.add(att); } if (line.indexOf(" (x)") != -1) { att.setIndex(-1); } } return new Pair, Attribute>(attributes, targetAtt); } } ================================================ FILE: src/main/java/mltk/core/io/InstancesReader.java ================================================ package mltk.core.io; import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.TreeSet; import mltk.core.Attribute; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.core.NumericalAttribute; import mltk.util.MathUtils; import mltk.util.tuple.Pair; /** * Class for reading instances. * * @author Yin Lou * */ public class InstancesReader { /** * Reads a set of instances from attribute file and data file. Attribute file can be null. Default delimiter is * whitespace. * * @param attFile the attribute file. * @param dataFile the data file. * @return a set of instances. * @throws IOException */ public static Instances read(String attFile, String dataFile) throws IOException { return read(attFile, dataFile, "\\s+"); } /** * Reads a set of instances from attribute file and data file. Attribute file can be null. 
* * @param attFile the attribute file. * @param dataFile the data file. * @param delimiter the delimiter. * @return a set of instances. * @throws IOException */ public static Instances read(String attFile, String dataFile, String delimiter) throws IOException { if (attFile != null) { Pair, Attribute> pair = AttributesReader.read(attFile); int classIndex = -1; if (pair.v2 != null) { classIndex = pair.v2.getIndex(); pair.v2.setIndex(-1); } List attributes = pair.v1; Instances instances = new Instances(attributes, pair.v2); int totalLength = instances.dimension(); if (classIndex != -1) { totalLength++; } BufferedReader br = new BufferedReader(new FileReader(dataFile), 65535); for (;;) { String line = br.readLine(); if (line == null) { break; } String[] data = line.split(delimiter); Instance instance = null; if (data.length >= 2 && data[1].indexOf(':') >= 0) { // Sparse instance instance = parseSparseInstance(data); } else if (data.length == totalLength) { // Dense instance instance = parseDenseInstance(data, classIndex); } else { System.err.println("Processed as dense vector but the number of attributes provided in the attribute file does not match with the number of columns in this row"); } if (instance != null) { instances.add(instance); } } br.close(); // Process skipped features for (int i = attributes.size() - 1; i >= 0; i--) { if (attributes.get(i).getIndex() < 0) { attributes.remove(i); } } return instances; } else { List attributes = new ArrayList<>(); Instances instances = new Instances(attributes); int totalLength = -1; TreeSet attrSet = new TreeSet<>(); BufferedReader br = new BufferedReader(new FileReader(dataFile), 65535); for (;;) { String line = br.readLine(); if (line == null) { break; } String[] data = line.split(delimiter); Instance instance = null; if (data.length >= 2 && data[1].indexOf(':') >= 0) { // Sparse instance instance = parseSparseInstance(data, attrSet); } else { // Dense instance if (totalLength == -1) { totalLength = data.length; } else 
if (data.length == totalLength) { instance = parseDenseInstance(data, -1); } } if (instance != null) { instances.add(instance); } } br.close(); if (totalLength == -1) { for (Integer attIndex : attrSet) { Attribute att = new NumericalAttribute("f" + attIndex); att.setIndex(attIndex); attributes.add(att); } } else { for (int j = 0; j < totalLength; j++) { Attribute att = new NumericalAttribute("f" + j); att.setIndex(j); attributes.add(att); } } assignTargetAttribute(instances); return instances; } } /** * Reads a set of dense instances from data file. Default delimiter is whitespace. * * @param file the data file. * @param targetIndex the index of the target attribute, -1 if no target attribute. * @return a set of dense instances. * @throws IOException */ public static Instances read(String file, int targetIndex) throws IOException { return read(file, targetIndex, "\\s+"); } /** * Reads a set of dense instances from data file. * * @param file the data file. * @param targetIndex the index of the target attribute, -1 if no target attribute. * @param delimiter the delimiter. * @return a set of dense instances. * @throws IOException */ public static Instances read(String file, int targetIndex, String delimiter) throws IOException { BufferedReader br = new BufferedReader(new FileReader(file), 65535); List attributes = new ArrayList<>(); Instances instances = new Instances(attributes); for (;;) { String line = br.readLine(); if (line == null) { break; } String[] data = line.split(delimiter); Instance instance = parseDenseInstance(data, targetIndex); instances.add(instance); } br.close(); int numAttributes = instances.get(0).getValues().length; for (int i = 0; i < numAttributes; i++) { Attribute att = new NumericalAttribute("f" + i); att.setIndex(i); attributes.add(att); } if (targetIndex >= 0) { assignTargetAttribute(instances); } return instances; } /** * Parses a dense instance from strings. * * @param data the string array. * @param classIndex the class index. 
* @return a dense instance from strings. */ static Instance parseDenseInstance(String[] data, int classIndex) { double classValue = Double.NaN; if (classIndex < 0) { double[] vector = new double[data.length]; for (int i = 0; i < data.length; i++) { vector[i] = parseDouble(data[i]); } return new Instance(vector, classValue); } else { double[] vector = new double[data.length - 1]; for (int i = 0; i < data.length; i++) { double value = parseDouble(data[i]); if (i < classIndex) { vector[i] = value; } else if (i > classIndex) { vector[i - 1] = value; } else { classValue = value; } } return new Instance(vector, classValue); } } /** * Parses a sparse instance from strings. * * @param data the string array. * @param attrSet the attributes set. * @return a sparse instance from strings. */ private static Instance parseSparseInstance(String[] data, TreeSet attrSet) { double targetValue = Double.parseDouble(data[0]); int[] indices = new int[data.length - 1]; double[] values = new double[data.length - 1]; for (int i = 0; i < indices.length; i++) { String[] pair = data[i + 1].split(":"); indices[i] = Integer.parseInt(pair[0]); values[i] = Double.parseDouble(pair[1]); attrSet.add(indices[i]); } return new Instance(indices, values, targetValue); } /** * Parses a sparse instance from strings. * * @param data the string array. * @return a sparse instance from strings. */ private static Instance parseSparseInstance(String[] data) { double classValue = Double.parseDouble(data[0]); int[] indices = new int[data.length - 1]; double[] values = new double[data.length - 1]; for (int i = 0; i < indices.length; i++) { String[] pair = data[i + 1].split(":"); indices[i] = Integer.parseInt(pair[0]); values[i] = Double.parseDouble(pair[1]); } return new Instance(indices, values, classValue); } /** * Assigns target attribute for a dataset. * * @param instances the data set. 
*/ private static void assignTargetAttribute(Instances instances) { boolean isInteger = true; for (Instance instance : instances) { if (!MathUtils.isInteger(instance.getTarget())) { isInteger = false; break; } } if (isInteger) { TreeSet set = new TreeSet<>(); for (Instance instance : instances) { double target = instance.getTarget(); set.add((int) target); } String[] states = new String[set.size()]; int i = 0; for (Integer v : set) { states[i++] = v.toString(); } instances.setTargetAttribute(new NominalAttribute("target", states)); } else { instances.setTargetAttribute(new NumericalAttribute("target")); } } /** * Parses double value from a string. Missing value is supported. * * @param s the string to parse. * @return double value. */ private static double parseDouble(String s) { if (s.equals("?")) { return Double.NaN; } else { return Double.parseDouble(s); } } } ================================================ FILE: src/main/java/mltk/core/io/InstancesWriter.java ================================================ package mltk.core.io; import java.io.IOException; import java.io.PrintWriter; import java.util.List; import mltk.core.Attribute; import mltk.core.Instance; import mltk.core.Instances; /** * Class for writing instances. * * @author Yin Lou * */ public class InstancesWriter { /** * Writes a set of dense instances to attribute file and data file. * * @param instances the dense instances to write. * @param attFile the attribute file path. * @param dataFile the data file path. * @throws IOException */ public static void write(Instances instances, String attFile, String dataFile) throws IOException { List attributes = instances.getAttributes(); PrintWriter out = new PrintWriter(attFile); for (Attribute attribute : attributes) { out.println(attribute); } out.println(instances.getTargetAttribute() + " (target)"); out.flush(); out.close(); write(instances, dataFile); } /** * Writes a set of dense/sparse instances to data file. 
* * @param instances the dense instances to write. * @param file the data file path. * @throws IOException */ public static void write(Instances instances, String file) throws IOException { PrintWriter out = new PrintWriter(file); for (Instance instance : instances) { out.println(instance); } out.flush(); out.close(); } } ================================================ FILE: src/main/java/mltk/core/io/package-info.java ================================================ /** * Contains classes for reading and writing datasets. */ package mltk.core.io; ================================================ FILE: src/main/java/mltk/core/package-info.java ================================================ /** * Provides classes and interfaces for handling datasets and attributes. */ package mltk.core; ================================================ FILE: src/main/java/mltk/core/processor/Discretizer.java ================================================ package mltk.core.processor; import java.util.ArrayList; import java.util.Collections; import java.util.List; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.core.Attribute; import mltk.core.BinnedAttribute; import mltk.core.Bins; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.Attribute.Type; import mltk.core.io.AttributesReader; import mltk.core.io.InstancesReader; import mltk.core.io.InstancesWriter; import mltk.util.Element; import mltk.util.tuple.DoublePair; /** * Class for discretizers. 
* * @author Yin Lou * */ public class Discretizer { static class Options { @Argument(name = "-r", description = "attribute file path") String attPath = null; @Argument(name = "-t", description = "training file path") String trainPath = null; @Argument(name = "-i", description = "input dataset path", required = true) String inputPath = null; @Argument(name = "-d", description = "discretized attribute file path") String disAttPath = null; @Argument(name = "-m", description = "output attribute file path") String outputAttPath = null; @Argument(name = "-o", description = "output dataset path", required = true) String outputPath = null; @Argument(name = "-n", description = "maximum num of bins (default: 256)") int maxNumBins = 256; } /** * Discretizes datasets. * *
	 * Usage: mltk.core.processor.Discretizer
	 * -i	input dataset path
	 * -o	output dataset path
	 * [-r]	attribute file path
	 * [-t]	training file path
	 * [-d]	discretized attribute file path
	 * [-m]	output attribute file path
	 * [-n]	maximum num of bins (default: 256)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options app = new Options(); CmdLineParser parser = new CmdLineParser(Discretizer.class, app); try { parser.parse(args); if (app.maxNumBins < 0) { throw new IllegalArgumentException(); } } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } List attributes = null; if (app.trainPath != null) { Instances trainSet = InstancesReader.read(app.attPath, app.trainPath); attributes = trainSet.getAttributes(); for (int i = 0; i < attributes.size(); i++) { Attribute attribute = attributes.get(i); if (attribute.getType() == Type.NUMERIC) { // Only discretize numeric attributes Discretizer.discretize(trainSet, i, app.maxNumBins); } } } else if (app.disAttPath != null) { attributes = AttributesReader.read(app.disAttPath).v1; } else { parser.printUsage(); System.exit(1); } Instances instances = InstancesReader.read(app.attPath, app.inputPath); List attrs = instances.getAttributes(); for (int i = 0; i < attrs.size(); i++) { Attribute attr = attrs.get(i); if (attr.getType() == Type.NUMERIC) { BinnedAttribute binnedAttr = (BinnedAttribute) attributes.get(i); // Only discretize numeric attributes Discretizer.discretize(instances, i, binnedAttr.getBins()); } } if (app.outputAttPath != null) { InstancesWriter.write(instances, app.outputAttPath, app.outputPath); } else { InstancesWriter.write(instances, app.outputPath); } } /** * Compute bins for a list of values. * * @param x the vector of input data. * @param maxNumBins the number of bins. * @return bins for a list of values. */ public static Bins computeBins(double[] x, int maxNumBins) { List> list = new ArrayList<>(); for (double v : x) { if (!Double.isNaN(v)) { list.add(new Element(1.0, v)); } } return computeBins(list, maxNumBins); } /** * Compute bins for a specified attribute. * * @param instances the dataset to discretize. * @param attIndex the attribute index. 
* @param maxNumBins the number of bins. * @return bins for a specified attribute. */ public static Bins computeBins(Instances instances, int attIndex, int maxNumBins) { Attribute attribute = instances.getAttributes().get(attIndex); List> list = new ArrayList<>(); for (Instance instance : instances) { if (!instance.isMissing(attribute.getIndex())) { list.add(new Element(instance.getWeight(), instance.getValue(attribute))); } } return computeBins(list, maxNumBins); } /** * Compute bins for a list of values. * * @param list the histogram. * @param maxNumBins the number of bins. * @return bins for a list of values. */ public static Bins computeBins(List> list, int maxNumBins) { Collections.sort(list); List stats = new ArrayList<>(); getStats(list, stats); if (stats.size() <= maxNumBins) { double[] a = new double[stats.size()]; for (int i = 0; i < a.length; i++) { a[i] = stats.get(i).v1; } return new Bins(a, a); } else { double totalWeight = 0; for (DoublePair stat : stats) { totalWeight += stat.v2; } double binSize = totalWeight / maxNumBins; List boundaryList = new ArrayList<>(); List medianList = new ArrayList<>(); int start = 0; double weight = 0; for (int i = 0; i < stats.size(); i++) { weight += stats.get(i).v2; totalWeight -= stats.get(i).v2; if (weight >= binSize) { if (i == start) { boundaryList.add(stats.get(start).v1); medianList.add(stats.get(start).v1); weight = 0; start = i + 1; } else { double d1 = weight - binSize; double d2 = stats.get(i).v2 - d1; if (d1 < d2) { boundaryList.add(stats.get(i).v1); medianList.add(getMedian(stats, start, weight / 2)); start = i + 1; weight = 0; } else { weight -= stats.get(i).v2; boundaryList.add(stats.get(i - 1).v1); medianList.add(getMedian(stats, start, weight / 2)); start = i; weight = stats.get(i).v2; } } binSize = (totalWeight + weight) / (maxNumBins - boundaryList.size()); } else if (i == stats.size() - 1) { boundaryList.add(stats.get(i).v1); medianList.add(getMedian(stats, start, weight / 2)); } } double[] 
boundaries = new double[boundaryList.size()]; double[] medians = new double[medianList.size()]; for (int i = 0; i < boundaries.length; i++) { boundaries[i] = boundaryList.get(i); medians[i] = medianList.get(i); } return new Bins(boundaries, medians); } } /** * Discretizes an attribute using bins. * * @param instances the dataset to discretize. * @param attIndex the attribute index. * @param bins the bins. */ public static void discretize(Instances instances, int attIndex, Bins bins) { Attribute attribute = instances.getAttributes().get(attIndex); BinnedAttribute binnedAttribute = new BinnedAttribute(attribute.getName(), bins); binnedAttribute.setIndex(attribute.getIndex()); instances.getAttributes().set(attIndex, binnedAttribute); for (Instance instance : instances) { if (!instance.isMissing(attribute.getIndex())) { int v = bins.getIndex(instance.getValue(attribute.getIndex())); instance.setValue(attribute.getIndex(), v); } } } /** * Discretized an attribute with specified number of bins. * * @param instances the dataset to discretize. * @param attIndex the attribute index. * @param maxNumBins the number of bins. 
*/ public static void discretize(Instances instances, int attIndex, int maxNumBins) { Bins bins = computeBins(instances, attIndex, maxNumBins); discretize(instances, attIndex, bins); } static double getMedian(List stats, int start, double midPoint) { double weight = 0; for (int i = start; i < stats.size(); i++) { weight += stats.get(i).v2; if (weight >= midPoint) { return stats.get(i).v1; } } return stats.get((start + stats.size()) / 2).v1; } static void getStats(List> list, List stats) { if (list.size() == 0) { return; } double totalWeight = list.get(0).element; double lastValue = list.get(0).weight; for (int i = 1; i < list.size(); i++) { Element element = list.get(i); double value = element.weight; double weight = element.element; if (value != lastValue) { stats.add(new DoublePair(lastValue, totalWeight)); lastValue = value; totalWeight = weight; } else { totalWeight += weight; } } stats.add(new DoublePair(lastValue, totalWeight)); } /** * Constructor. */ public Discretizer() { } } ================================================ FILE: src/main/java/mltk/core/processor/InstancesSplitter.java ================================================ package mltk.core.processor; import java.io.File; import java.util.ArrayList; import java.util.List; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.core.Attribute; import mltk.core.Attribute.Type; import mltk.core.BinnedAttribute; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.core.io.InstancesReader; import mltk.core.io.InstancesWriter; import mltk.util.MathUtils; import mltk.util.Random; import mltk.util.StatUtils; /** * Class for cross validation. 
* * @author Yin Lou * */ public class InstancesSplitter { static class Options { @Argument(name = "-r", description = "attribute file path") String attPath = null; @Argument(name = "-i", description = "input dataset path", required = true) String inputPath = null; @Argument(name = "-o", description = "output directory path", required = true) String outputDirPath = null; @Argument(name = "-m", description = "splitting mode:parameter. Splitting mode can be split (s) and cross validation (c) (default: c:5)") String crossValidationMode = "c:5"; @Argument(name = "-a", description = "attribute name to perform stratified sampling (default: null)") String attToStrafity = null; @Argument(name = "-s", description = "seed of the random number generator (default: 0)") long seed = 0L; } /** * Splits a dataset. * *
	 * Usage: mltk.core.processor.InstancesSplitter
	 * -i	input dataset path
	 * -o	output directory path
	 * [-r]	attribute file path
	 * [-m]	splitting mode:parameter. Splitting mode can be split (s) and cross validation (c) (default: c:5)
	 * [-a]	attribute name to perform stratified sampling (default: null)
	 * [-s]	seed of the random number generator (default: 0)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(InstancesSplitter.class, opts); String[] data = null; try { parser.parse(args); data = opts.crossValidationMode.split(":"); if (data.length < 2) { throw new IllegalArgumentException(); } } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Random.getInstance().setSeed(opts.seed); Instances instances = InstancesReader.read(opts.attPath, opts.inputPath); File attFile = new File(opts.attPath); String prefix = attFile.getName().split("\\.")[0]; File dir = new File(opts.outputDirPath); if (!dir.exists()) { dir.mkdirs(); } switch (data[0]) { case "c": int k = Integer.parseInt(data[1]); if (data.length == 2) { Instances[][] folds = InstancesSplitter.createCrossValidationFolds(instances, opts.attToStrafity, k); for (int i = 0; i < folds.length; i++) { String path = opts.outputDirPath + File.separator + "cv." + i; File directory = new File(path); if (!directory.exists()) { directory.mkdir(); } InstancesWriter.write(folds[i][0], path + File.separator + prefix + ".attr", path + File.separator + prefix + ".train.all"); InstancesWriter.write(folds[i][1], path + File.separator + prefix + ".test"); } } else { double ratio = Double.parseDouble(data[2]); Instances[][] folds = InstancesSplitter.createCrossValidationFolds(instances, opts.attToStrafity, k, ratio); for (int i = 0; i < folds.length; i++) { String path = opts.outputDirPath + File.separator + "cv." 
+ i; File directory = new File(path); if (!directory.exists()) { directory.mkdir(); } InstancesWriter.write(folds[i][0], path + File.separator + prefix + ".attr", path + File.separator + prefix + ".train"); InstancesWriter.write(folds[i][1], path + File.separator + prefix + ".valid"); InstancesWriter.write(folds[i][2], path + File.separator + prefix + ".test"); } } break; case "s": if (data.length == 2) { double ratio = Double.parseDouble(data[1]); Instances[] datasets = InstancesSplitter.split(instances, opts.attToStrafity, ratio); InstancesWriter.write(datasets[0], opts.outputDirPath + File.separator + prefix + ".attr", opts.outputDirPath + File.separator + prefix + ".train"); InstancesWriter.write(datasets[1], opts.outputDirPath + File.separator + prefix + ".valid"); } else if (data.length == 3) { double ratioTrain = Double.parseDouble(data[1]); double ratioValid = Double.parseDouble(data[2]); double[] ratios = new double[] { ratioTrain, ratioValid }; Instances[] datasets = InstancesSplitter.split(instances, opts.attToStrafity, ratios); InstancesWriter.write(datasets[0], opts.outputDirPath + File.separator + prefix + ".attr", opts.outputDirPath + File.separator + prefix + ".train"); InstancesWriter.write(datasets[1], opts.outputDirPath + File.separator + prefix + ".valid"); } else if (data.length == 4) { double ratioTrain = Double.parseDouble(data[1]); double ratioValid = Double.parseDouble(data[2]); double ratioTest = Double.parseDouble(data[3]); double[] ratios = new double[] { ratioTrain, ratioValid, ratioTest }; Instances[] datasets = InstancesSplitter.split(instances, opts.attToStrafity, ratios); InstancesWriter.write(datasets[0], opts.outputDirPath + File.separator + prefix + ".attr", opts.outputDirPath + File.separator + prefix + ".train"); InstancesWriter.write(datasets[1], opts.outputDirPath + File.separator + prefix + ".valid"); InstancesWriter.write(datasets[2], opts.outputDirPath + File.separator + prefix + ".test"); } else { double[] ratios = new 
double[data.length - 1]; for (int i = 0; i < ratios.length; i++) { ratios[i] = Double.parseDouble(data[i + 1]); } Instances[] datasets = InstancesSplitter.split(instances, opts.attToStrafity, ratios); for (int i = 0; i < datasets.length; i++) { InstancesWriter.write(datasets[i], opts.outputDirPath + File.separator + prefix + ".data." + i); } } break; default: break; } } /** * Creates cross validation folds from a dataset. For each cross validation fold contains a training set and a test * set. * * @param instances the dataset. * @param k the number of cross validation folds. * @return the cross validation datasets. */ public static Instances[][] createCrossValidationFolds(Instances instances, int k) { Instances[] datasets = split(instances, k); Instances[][] folds = new Instances[k][2]; for (int i = 0; i < k; i++) { folds[i][1] = datasets[i]; folds[i][0] = new Instances(instances.getAttributes(), instances.getTargetAttribute()); for (int j = 0; j < k; j++) { if (i == j) { continue; } for (Instance instance : datasets[j]) { folds[i][0].add(instance); } } } return folds; } /** * Creates cross validation folds from a dataset. For each cross validation fold contains a training set, a * validation set and a test set. * * @param instances the dataset. * @param k the number of cross validation folds. * @param ratio the ratio that controls how many points in the training set for each fold. * @return the cross validation datasets. 
*/
	public static Instances[][] createCrossValidationFolds(Instances instances, int k, double ratio) {
		Instances[] parts = split(instances, k);
		Instances[][] folds = new Instances[k][3];
		for (int fold = 0; fold < k; fold++) {
			// The held-out part is the test set
			folds[fold][2] = parts[fold];
			// Pool every other part, then divide the pool into training and validation sets
			Instances pool = new Instances(instances.getAttributes(), instances.getTargetAttribute());
			for (int other = 0; other < k; other++) {
				if (other != fold) {
					for (Instance instance : parts[other]) {
						pool.add(instance);
					}
				}
			}
			Instances[] trainValid = split(pool, ratio);
			folds[fold][0] = trainValid[0];
			folds[fold][1] = trainValid[1];
		}
		return folds;
	}

	/**
	 * Creates cross validation folds from a dataset, optionally stratified on an attribute. Each fold holds a
	 * training set (index 0) and a test set (index 1).
	 *
	 * @param instances the dataset.
	 * @param attToStratify the attribute to perform stratified sampling.
	 * @param k the number of cross validation folds.
	 * @return the cross validation datasets.
	 */
	public static Instances[][] createCrossValidationFolds(Instances instances, String attToStratify, int k) {
		Instances[] parts = split(instances, attToStratify, k);
		Instances[][] folds = new Instances[k][2];
		for (int fold = 0; fold < k; fold++) {
			folds[fold][1] = parts[fold];
			Instances train = new Instances(instances.getAttributes(), instances.getTargetAttribute());
			folds[fold][0] = train;
			for (int other = 0; other < k; other++) {
				if (other != fold) {
					for (Instance instance : parts[other]) {
						train.add(instance);
					}
				}
			}
		}
		return folds;
	}

	/**
	 * Creates cross validation folds from a dataset. Each cross validation fold contains a training set, a
	 * validation set and a test set.
	 *
	 * @param instances the dataset.
	 * @param attToStratify the attribute to perform stratified sampling.
	 * @param k the number of cross validation folds.
	 * @param ratio the ratio that controls how many points in the training set for each fold.
	 * @return the cross validation datasets.
*/ public static Instances[][] createCrossValidationFolds(Instances instances, String attToStratify, int k, double ratio) { Instances[] datasets = split(instances, attToStratify, k); Instances[][] folds = new Instances[k][3]; for (int i = 0; i < k; i++) { folds[i][2] = datasets[i]; Instances trainSet = new Instances(instances.getAttributes(), instances.getTargetAttribute()); for (int j = 0; j < k; j++) { if (i == j) { continue; } for (Instance instance : datasets[j]) { trainSet.add(instance); } } Instances[] tmp = split(trainSet, attToStratify, ratio); folds[i][0] = tmp[0]; folds[i][1] = tmp[1]; } return folds; } /** * Splits the dataset according to the ratios. This method returns multiple instances objects, the size of each * partition is determined by the ratios array. The sum of ratios can be smaller than 1. * * @param instances the dataset. * @param ratios the ratios. * @return partitions of the dataset. */ public static Instances[] split(Instances instances, double... ratios) { if (StatUtils.sum(ratios) > 1) { throw new IllegalArgumentException("Sum of ratios is larger than 1"); } Instances dataset = new Instances(instances); dataset.shuffle(Random.getInstance().getRandom()); Instances[] datasets = new Instances[ratios.length]; for (int i = 0; i < datasets.length; i++) { datasets[i] = new Instances(dataset.getAttributes(), dataset.getTargetAttribute()); } double sumRatios = StatUtils.sum(ratios); int n = 0; for (int k = 0; k < datasets.length; k++) { int m = (int) (dataset.size() * ratios[k]); if (k == datasets.length - 1 && MathUtils.equals(sumRatios, 1.0)) { m = dataset.size() - n; } Instances partition = datasets[k]; for (int i = n; i < n + m; i++) { partition.add(dataset.get(i)); } n += m; } return datasets; } /** * Splits the dataset according to the ratio. This method returns two instances objects, the size of the first one * is 100% * ratio of the orignal dataset while the size of the second one is 100% * (1 - ratio) of the orignal * dataset. 
* * @param instances the dataset. * @param ratio the ratio. * @return two smaller datasets. */ public static Instances[] split(Instances instances, double ratio) { return split(instances, new double[] { ratio, 1 - ratio }); } /** * Splits the dataset into k equi-sized datasets. * * @param instances the dataset. * @param k the number of datasets to return. * @return k equi-sized datasets. */ public static Instances[] split(Instances instances, int k) { Instances dataset = new Instances(instances); dataset.shuffle(Random.getInstance().getRandom()); Instances[] datasets = new Instances[k]; for (int i = 0; i < datasets.length; i++) { datasets[i] = new Instances(dataset.getAttributes(), dataset.getTargetAttribute()); } for (int i = 0; i < dataset.size(); i++) { datasets[i % datasets.length].add(dataset.get(i)); } return datasets; } /** * Splits the dataset according to the ratio. This method returns two instances objects, the size of the first one * is 100% * ratio of the orignal dataset while the size of the second one is 100% * (1 - ratio) of the orignal * dataset. * * @param instances the dataset. * @param attToStratify the attribute to perform stratified sampling. * @param ratio the ratio. * @return two smaller datasets. */ public static Instances[] split(Instances instances, String attToStratify, double ratio) { return split(instances, attToStratify, new double[] { ratio, 1 - ratio }); } /** * Splits the dataset according to the ratios. This method returns multiple instances objects, the size of each * partition is determined by the ratios array. The sum of ratios can be smaller than 1. * * @param instances the dataset. * @param attToStratify the attribute to perform stratified sampling. * @param ratios the ratios. * @return partitions of the dataset. */ public static Instances[] split(Instances instances, String attToStratify, double... 
ratios) { if (attToStratify == null) { return split(instances, ratios); } List> strata = getStrata(instances, attToStratify); if (strata == null) { return split(instances, ratios); } Instances[] datasets = new Instances[ratios.length]; for (int i = 0; i < datasets.length; i++) { datasets[i] = new Instances(instances.getAttributes(), instances.getTargetAttribute()); } double sumRatios = StatUtils.sum(ratios); for (List list : strata) { int n = 0; for (int k = 0; k < datasets.length; k++) { int m = (int) (list.size() * ratios[k]); if (k == datasets.length -1 && MathUtils.equals(sumRatios, 1.0)) { m = list.size() - n; } Instances partition = datasets[k]; for (int i = n; i < n + m; i++) { partition.add(list.get(i)); } n += m; } } return datasets; } /** * Splits the dataset into k equi-sized datasets. * * @param instances the dataset. * @param attToStratify the attribute to perform stratified sampling. * @param k the number of datasets to return. * @return k equi-sized datasets. */ public static Instances[] split(Instances instances, String attToStratify, int k) { if (attToStratify == null) { return split(instances, k); } List> strata = getStrata(instances, attToStratify); if (strata == null) { return split(instances, k); } Instances[] datasets = new Instances[k]; for (int i = 0; i < datasets.length; i++) { datasets[i] = new Instances(instances.getAttributes(), instances.getTargetAttribute()); } for (List stratum : strata) { for (int i = 0; i < stratum.size(); i++) { datasets[i % datasets.length].add(stratum.get(i)); } } return datasets; } private static List> getStrata(Instances instances, String attToStratify) { List> lists = new ArrayList<>(); Instances dataset = new Instances(instances); dataset.shuffle(Random.getInstance().getRandom()); if (instances.getTargetAttribute().getName().equals(attToStratify)) { Attribute targetAtt = instances.getTargetAttribute(); if (targetAtt == null || targetAtt.getType() == Type.NUMERIC) { return null; } int cardinality = 0; if 
(targetAtt.getType() == Type.BINNED) { cardinality = ((BinnedAttribute) targetAtt).getNumBins(); } else { cardinality = ((NominalAttribute) targetAtt).getCardinality(); } for (int i = 0; i < cardinality; i++) { lists.add(new ArrayList()); } for (Instance instance : instances) { int idx = (int) instance.getTarget(); lists.get(idx).add(instance); } } else { List attributes = instances.getAttributes(); Attribute attr = null; for (Attribute att : attributes) { if (att.getName().equals(attToStratify)) { attr = att; break; } } if (attr == null || attr.getType() == Type.NUMERIC) { return null; } int cardinality = 0; if (attr.getType() == Type.BINNED) { cardinality = ((BinnedAttribute) attr).getNumBins(); } else { cardinality = ((NominalAttribute) attr).getCardinality(); } for (int i = 0; i < cardinality; i++) { lists.add(new ArrayList()); } for (Instance instance : dataset) { int idx = (int) instance.getValue(attr); lists.get(idx).add(instance); } } return lists; } } ================================================ FILE: src/main/java/mltk/core/processor/OneHotEncoder.java ================================================ package mltk.core.processor; import java.util.ArrayList; import java.util.List; import mltk.core.Attribute; import mltk.core.BinnedAttribute; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.core.NumericalAttribute; /** * Class for one-hot encoders. Binned attributes and nominal attributes are * transformed into a set of binary attributes using one-hot encoding. * * @author Yin Lou * */ public class OneHotEncoder { /** * Transforms all binned and nominal attributes to binary attributes using * one-hot encoding. * * @param instances the input instances. * @return the transformed instances. 
*/ public Instances process(Instances instances) { List attrListOld = instances.getAttributes(); List attrListNew = new ArrayList<>(); int[] offset = new int[instances.dimension()]; boolean[] isNumerical = new boolean[instances.dimension()]; int attIndex = 0; for (int j = 0; j < attrListOld.size(); j++) { Attribute attribute = attrListOld.get(j); offset[j] = attIndex; String name = attribute.getName(); if (attribute instanceof BinnedAttribute) { BinnedAttribute binnedAttribute = (BinnedAttribute) attribute; int size = binnedAttribute.getNumBins(); for (int k = 0; k < size; k++) { NumericalAttribute attr = new NumericalAttribute(name + "_" + k); attr.setIndex(attIndex++); attrListNew.add(attr); } } else if (attribute instanceof NominalAttribute) { NominalAttribute nominalAttribute = (NominalAttribute) attribute; String[] states = nominalAttribute.getStates(); for (String state : states) { NumericalAttribute attr = new NumericalAttribute(name + "_" + state); attr.setIndex(attIndex++); attrListNew.add(attr); } } else { NumericalAttribute attr = new NumericalAttribute(name); attr.setIndex(attIndex++); attrListNew.add(attr); isNumerical[j] = true; } } Instances instancesNew = new Instances(attrListNew, instances.getTargetAttribute(), instances.size()); for (Instance instance : instances) { int[] indices = new int[instances.dimension()]; double[] values = new double[instances.dimension()]; for (int j = 0; j < attrListOld.size(); j++) { if (isNumerical[j]) { indices[j] = offset[j]; values[j] = instance.getValue(attrListOld.get(j)); } else { int v = (int) instance.getValue(attrListOld.get(j)); indices[j] = offset[j] + v; values[j] = 1.0; } } Instance instanceNew = new Instance(indices, values, instance.getTarget(), instance.getWeight()); instancesNew.add(instanceNew); } return instancesNew; } } ================================================ FILE: src/main/java/mltk/core/processor/package-info.java ================================================ /** * Provides classes 
for processing datasets.
 */
package mltk.core.processor;

================================================
FILE: src/main/java/mltk/feature/selection/BackwardElimination.java
================================================
package mltk.feature.selection;

import java.util.ArrayList;
import java.util.List;

import mltk.core.Attribute;
import mltk.core.Instances;
import mltk.predictor.BaggedEnsembleLearner;
import mltk.predictor.Regressor;
import mltk.predictor.evaluation.Evaluator;
import mltk.predictor.tree.ensemble.ag.AdditiveGrovesLearner;
import mltk.util.StatUtils;
import mltk.util.tuple.DoublePair;
import mltk.util.tuple.Pair;

/**
 * Class for feature selection using backward elimination.
 *
 * @author Yin Lou
 *
 */
public class BackwardElimination {

	/**
	 * Selects features using backward elimination.
	 *
	 * @param trainSet the training set.
	 * @param validSet the validation set.
	 * @param learner the learner to use.
	 * @param numIters the number of iterations to estimate the mean and std for full complexity models.
	 * @return the list of selected features and &lt;mean, std&gt; pair for full complexity models.
*/ public static Pair, DoublePair> select(Instances trainSet, Instances validSet, BaggedEnsembleLearner learner, int numIters) { List attributes = trainSet.getAttributes(); List selected = new ArrayList<>(attributes); DoublePair perf = null; for (;;) { if (selected.size() == 0) { break; } boolean changed = false; trainSet.setAttributes(selected); perf = evaluateModel(trainSet, validSet, learner, numIters); System.out.println("Mean: " + perf.v1 + " Std: " + perf.v2); int i; for (i = 0; i < selected.size();) { List attList = new ArrayList<>(selected); Attribute attr = attList.get(i); attList.remove(i); trainSet.setAttributes(attList); Regressor regressor = (Regressor) learner.build(trainSet); double rmse = Evaluator.evalRMSE(regressor, validSet); System.out.println("Testing: " + attr.getName() + " RMSE: " + rmse); if (perf.v1 - perf.v2 * 3 <= rmse && rmse <= perf.v1 + perf.v2 * 3) { // Eliminate feature selected.remove(i); changed = true; System.out.println("Eliminate: " + attr.getName()); } else { i++; } } if (!changed) { break; } } trainSet.setAttributes(attributes); return new Pair, DoublePair>(selected, perf); } /** * Selects features using backward elimination in Additive Groves. * * @param trainSet the training set. * @param validSet the validation set. * @param learner the learner to use. * @param baggingIters the number of bagging iterations. * @param numTrees the number of trees in a grove. * @param alpha the alpha. * @param numIters the number of iterations to estimate the mean and std for full complexity models. * @return the list of selected features and $lt;mean, std$gt; pair for full complexity models. 
*/ public static Pair, DoublePair> select(Instances trainSet, Instances validSet, AdditiveGrovesLearner learner, int baggingIters, int numTrees, double alpha, int numIters) { List attributes = trainSet.getAttributes(); List selected = new ArrayList<>(attributes); DoublePair perf = null; for (;;) { if (selected.size() == 0) { break; } boolean changed = false; trainSet.setAttributes(selected); perf = evaluateModel(trainSet, validSet, learner, baggingIters, numTrees, alpha, numIters); System.out.println("Mean: " + perf.v1 + " Std: " + perf.v2); int i; for (i = 0; i < selected.size();) { List attList = new ArrayList<>(selected); Attribute attr = attList.get(i); System.out.println("Testing: " + attr.getName()); attList.remove(i); trainSet.setAttributes(attList); Regressor regressor = learner.runLayeredTraining(trainSet,baggingIters, numTrees, alpha); double rmse = Evaluator.evalRMSE(regressor, validSet); System.out.println("Testing: " + attr.getName() + " RMSE: " + rmse); if (perf.v1 - perf.v2 * 3 <= rmse && rmse <= perf.v1 + perf.v2 * 3) { // Eliminate feature selected.remove(i); changed = true; System.out.println("Eliminate: " + attr.getName()); } else { i++; } } if (!changed) { break; } } trainSet.setAttributes(attributes); return new Pair, DoublePair>(selected, perf); } private static DoublePair evaluateModel(Instances trainSet, Instances validSet, BaggedEnsembleLearner learner, int numIters) { // Estimating std of full complexity model double[] rmse = new double[numIters]; for (int i = 0; i < rmse.length; i++) { Regressor regressor = (Regressor) learner.build(trainSet); rmse[i] = Evaluator.evalRMSE(regressor, validSet); } double mean = StatUtils.mean(rmse); double std = StatUtils.sd(rmse); return new DoublePair(mean, std); } private static DoublePair evaluateModel(Instances trainSet, Instances validSet, AdditiveGrovesLearner learner, int baggingIters, int numTrees, double alpha, int numIters) { // Estimating std of full complexity model double[] rmse = new 
double[numIters]; for (int i = 0; i < rmse.length; i++) { Regressor regressor = learner.runLayeredTraining(trainSet, baggingIters, numTrees, alpha); rmse[i] = Evaluator.evalRMSE(regressor, validSet); System.out.println("\tEvaluating model " + (i + 1) + " / " + numIters + "\t" + rmse[i]); } double mean = StatUtils.mean(rmse); double std = StatUtils.sd(rmse); return new DoublePair(mean, std); } } ================================================ FILE: src/main/java/mltk/feature/selection/package-info.java ================================================ /** * Contains classes for feature selection. */ package mltk.feature.selection; ================================================ FILE: src/main/java/mltk/predictor/BaggedEnsemble.java ================================================ package mltk.predictor; import java.util.HashMap; import java.util.Map; import mltk.core.Instance; /** * Class for bagged ensembles. * * @author Yin Lou * */ public class BaggedEnsemble extends Ensemble { /** * Constructor. */ public BaggedEnsemble() { super(); } /** * Constructor. * * @param capacity the capacity of this bagged ensemble. 
*/ public BaggedEnsemble(int capacity) { super(capacity); } @Override public double regress(Instance instance) { if (predictors.size() == 0) { return 0.0; } else { double prediction = 0.0; for (Predictor predictor : predictors) { Regressor regressor = (Regressor) predictor; prediction += regressor.regress(instance); } return prediction / predictors.size(); } } @Override public int classify(Instance instance) { if (predictors.size() == 0) { // Default: return first class return 0; } else { Map votes = new HashMap<>(); for (Predictor predictor : predictors) { Classifier classifier = (Classifier) predictor; int cls = (int) classifier.classify(instance); if (!votes.containsKey(cls)) { votes.put(cls, 0); } votes.put(cls, votes.get(cls) + 1); } int prediction = 0; int maxVotes = 0; for (int cls : votes.keySet()) { int numVotes = votes.get(cls); if (numVotes > maxVotes) { maxVotes = numVotes; prediction = cls; } } return prediction; } } @Override public BaggedEnsemble copy() { BaggedEnsemble copy = new BaggedEnsemble(predictors.size()); for (Predictor predictor : predictors) { copy.add(predictor.copy()); } return copy; } } ================================================ FILE: src/main/java/mltk/predictor/BaggedEnsembleLearner.java ================================================ package mltk.predictor; import mltk.core.Instances; import mltk.core.Sampling; /** * Class for learning bagged ensembles. * * @author Yin Lou * */ public class BaggedEnsembleLearner extends Learner { protected int baggingIters; protected Learner learner; protected Instances[] bags; /** * Constructor. * * @param baggingIters the number of bagging iterations. * @param learner the learner. */ public BaggedEnsembleLearner(int baggingIters, Learner learner) { this.baggingIters = baggingIters; this.learner = learner; } /** * Returns the number of bagging iterations. * * @return the number of bagging iterations. 
*/ public int getBaggingIterations() { return baggingIters; } /** * Sets the number of bagging iterations. * * @param baggingIters the number of bagging iterations. */ public void setBaggingIterations(int baggingIters) { this.baggingIters = baggingIters; } /** * Returns the learner. * * @return the learner. */ public Learner getLearner() { return learner; } /** * Sets the learner. * * @param learner the learner. */ public void setLearner(Learner learner) { this.learner = learner; } /** * Returns the bootstrap samples. * * @return the bootstrap samples. */ public Instances[] getBags() { return bags; } /** * Sets the bootstrap samples. * * @param bags the bootstrap samples. */ public void setBags(Instances[] bags) { this.bags = bags; } @Override public BaggedEnsemble build(Instances instances) { // Create bags bags = Sampling.createBags(instances, baggingIters); BaggedEnsemble baggedEnsemble = new BaggedEnsemble(bags.length); for (Instances bag : bags) { baggedEnsemble.add(learner.build(bag)); } return baggedEnsemble; } /** * Builds a bagged ensemble. * * @param bags the bootstrap samples. * @return a bagged ensemble. */ public BaggedEnsemble build(Instances[] bags) { BaggedEnsemble baggedEnsemble = new BaggedEnsemble(bags.length); for (Instances bag : bags) { baggedEnsemble.add(learner.build(bag)); } return baggedEnsemble; } } ================================================ FILE: src/main/java/mltk/predictor/BoostedEnsemble.java ================================================ package mltk.predictor; import mltk.core.Instance; /** * Class for boosted ensembles. * * @author Yin Lou * */ public class BoostedEnsemble extends Ensemble { /** * Constructor. */ public BoostedEnsemble() { super(); } /** * Constructor. * * @param capacity the capacity of the boosted ensemble. 
*/ public BoostedEnsemble(int capacity) { super(capacity); } @Override public double regress(Instance instance) { double prediction = 0.0; for (Predictor predictor : predictors) { Regressor regressor = (Regressor) predictor; prediction += regressor.regress(instance); } return prediction; } @Override public int classify(Instance instance) { double pred = regress(instance); if (pred >= 0) { return 1; } else { return -1; } } /** * Removes a particular predictor. * * @param index the index of the predictor to remove. */ public void remove(int index) { predictors.remove(index); } /** * Removes the last predictor. */ public void removeLast() { if (predictors.size() > 0) { predictors.remove(predictors.size() - 1); } } @Override public BoostedEnsemble copy() { BoostedEnsemble copy = new BoostedEnsemble(predictors.size()); for (Predictor predictor : predictors) { copy.add(predictor.copy()); } return copy; } } ================================================ FILE: src/main/java/mltk/predictor/Classifier.java ================================================ package mltk.predictor; import mltk.core.Instance; /** * Interface for classfiers. * * @author Yin Lou * */ public interface Classifier extends Predictor { /** * Classifies an instance. * * @param instance the instance to classify. * @return a classified value. */ public int classify(Instance instance); } ================================================ FILE: src/main/java/mltk/predictor/Ensemble.java ================================================ package mltk.predictor; import java.io.BufferedReader; import java.io.PrintWriter; import java.util.ArrayList; import java.util.List; /** * Abstract class for ensembles. * * @author Yin Lou * */ public abstract class Ensemble implements Classifier, Regressor { protected List predictors; /** * Constructor. */ public Ensemble() { predictors = new ArrayList<>(); } /** * Constructor. * * @param capacity the capacity of this ensemble. 
*/ public Ensemble(int capacity) { predictors = new ArrayList<>(capacity); } /** * Returns a particular predictor. * * @param index the index of predictor. * @return a particular predictor. */ public Predictor get(int index) { return predictors.get(index); } /** * Returns the internal predictors. * * @return the internal predictors. */ public List getPredictors() { return predictors; } /** * Adds a new predictor to the ensemble. * * @param predictor the new predictor. */ public void add(Predictor predictor) { predictors.add(predictor); } /** * Returns the size of this ensemble. * * @return the size of this ensemble. */ public int size() { return predictors.size(); } /** * Clears this ensemble. */ public void clear() { predictors.clear(); } @Override public void read(BufferedReader in) throws Exception { int capacity = Integer.parseInt(in.readLine().split(": ")[1]); predictors = new ArrayList<>(capacity); in.readLine(); for (int i = 0; i < capacity; i++) { String line = in.readLine(); String predictorName = line.substring(1, line.length() - 1).split(": ")[1]; Class clazz = Class.forName(predictorName); Predictor predictor = (Predictor) clazz.getDeclaredConstructor().newInstance(); predictor.read(in); predictors.add(predictor); in.readLine(); } } @Override public void write(PrintWriter out) throws Exception { out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName()); out.println("Ensemble: " + predictors.size()); out.println(); for (Predictor predictor : predictors) { predictor.write(out); out.println(); } } } ================================================ FILE: src/main/java/mltk/predictor/Family.java ================================================ package mltk.predictor; /** * Class for response distribution family. This class is used for GLMs/GAMs. * * @author Yin Lou * */ public enum Family { GAUSSIAN("gaussian", LinkFunction.IDENTITY), BINOMIAL("binomial", LinkFunction.LOGIT); /** * Parses an enumeration from a string. 
* * @param name the family name. * @return a parsed distribution. */ public static Family get(String name) { for (Family family : Family.values()) { if (name.startsWith(family.name)) { return family; } } throw new IllegalArgumentException("Invalid family name: " + name); } String name; LinkFunction link; Family(String name, LinkFunction link) { this.name = name; this.link = link; } /** * Returns the default link function for this family. * * @return the default link function for this family. */ public LinkFunction getDefaultLinkFunction() { return link; } /** * Returns the string representation of this family with default link function. */ public String toString() { return name + "(" + link + ")"; } } ================================================ FILE: src/main/java/mltk/predictor/HoldoutValidatedLearner.java ================================================ package mltk.predictor; import mltk.core.Instances; import mltk.predictor.evaluation.ConvergenceTester; import mltk.predictor.evaluation.Metric; /** * Class for holdout validated learners. * * @author Yin Lou * */ public abstract class HoldoutValidatedLearner extends Learner { protected Instances validSet; protected Metric metric; protected ConvergenceTester ct; /** * Constructor. */ public HoldoutValidatedLearner() { ct = new ConvergenceTester(-1, 0, 1.0); } /** * Returns the validation set. * * @return the validation set. */ public Instances getValidSet() { return validSet; } /** * Sets the validation set. * * @param validSet the validation set. */ public void setValidSet(Instances validSet) { this.validSet = validSet; } /** * Returns the metric. * * @return the metric. */ public Metric getMetric() { return metric; } /** * Sets the metric. * * @param metric the metric. */ public void setMetric(Metric metric) { this.metric = metric; } /** * Returns the convergence tester. * * @return the convergence tester. */ public ConvergenceTester getConvergenceTester() { return ct; } /** * Sets the convergence tester. 
* * @param ct the convergence tester to set. */ public void setConvergenceTester(ConvergenceTester ct) { this.ct = ct; } } ================================================ FILE: src/main/java/mltk/predictor/Learner.java ================================================ package mltk.predictor; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.TreeMap; import mltk.core.Attribute; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.SparseVector; import mltk.predictor.evaluation.Error; import mltk.predictor.evaluation.Metric; import mltk.predictor.evaluation.RMSE; import mltk.util.MathUtils; import mltk.util.StatUtils; import mltk.util.VectorUtils; import mltk.util.tuple.IntDoublePair; /** * Class for learners. * * @author Yin Lou * */ public abstract class Learner { protected boolean verbose; /** * Returns {@code true} if we output something during the training. * * @return {@code true} if we output something during the training. */ public boolean isVerbose() { return verbose; } /** * Sets whether we output something during the training. * * @param verbose the switch if we output things during training. */ public void setVerbose(boolean verbose) { this.verbose = verbose; } /** * Enumeration of learning tasks. * */ public enum Task { /** * Classification task. */ CLASSIFICATION("classification"), /** * Regression task. */ REGRESSION("regression"); String task; Task(String task) { this.task = task; } /** * Returns the string representation of learning tasks. */ public String toString() { return task; } /** * Parses a task from a string. * * @param name the name of the task. * @return a parsed task. */ public static Task get(String name) { for (Task task : Task.values()) { if (task.task.startsWith(name)) { return task; } } throw new IllegalArgumentException("Invalid task name: " + name); } /** * Returns the default metric for this task. * * @return the default metric for this task. 
*/ public Metric getDefaultMetric() { Metric metric = null; switch (this) { case CLASSIFICATION: metric = new Error(); break; case REGRESSION: metric = new RMSE(); break; default: break; } return metric; } } /** * Builds a predictor from training set. * * @param instances the training set. * @return a predictior. */ public abstract Predictor build(Instances instances); /** * Returns {@code true} if the instances are treated as sparse. * * @param instances the instances to test. * @return {@code true} if the instances are treated as sparse. */ protected boolean isSparse(Instances instances) { int numSparseInstances = 0; for (Instance instance : instances) { if (instance.isSparse()) { numSparseInstances++; } } return numSparseInstances > instances.size() / 2; } /** * Returns the column-oriented format of sparse dataset. This method automatically removes attributes with * close-to-zero variance. * * @param instances the instances. * @param normalize {@code true} if all the columns are normalized. * @return the column-oriented format of sparse dataset. 
*/ protected SparseDataset getSparseDataset(Instances instances, boolean normalize) { List attributes = instances.getAttributes(); int maxAttrId = attributes.get(attributes.size() - 1).getIndex(); boolean[] included = new boolean[maxAttrId + 1]; for (Attribute attribute : attributes) { included[attribute.getIndex()] = true; } final int n = instances.size(); Map> map = new TreeMap<>(); double[] y = new double[n]; for (int i = 0; i < instances.size(); i++) { Instance instance = instances.get(i); SparseVector vector = (SparseVector) instance.getVector(); int[] indices = vector.getIndices(); double[] values = vector.getValues(); for (int j = 0; j < indices.length; j++) { if (included[indices[j]]) { if (!map.containsKey(indices[j])) { map.put(indices[j], new ArrayList()); } List list = map.get(indices[j]); list.add(new IntDoublePair(i, values[j])); } } y[i] = instance.getTarget(); } List attrsList = new ArrayList<>(map.size()); List indicesList = new ArrayList<>(map.size()); List valuesList = new ArrayList<>(map.size()); List stdList = new ArrayList<>(map.size()); List cList = null; if (normalize) { cList = new ArrayList<>(); } double factor = Math.sqrt(n); for (Map.Entry> entry : map.entrySet()) { Integer attr = entry.getKey(); List list = entry.getValue(); int[] indices = new int[list.size()]; double[] values = new double[list.size()]; for (int i = 0; i < list.size(); i++) { IntDoublePair pair = list.get(i); indices[i] = pair.v1; values[i] = pair.v2; } double std = StatUtils.sd(values, n); if (std > MathUtils.EPSILON) { attrsList.add(attr); indicesList.add(indices); valuesList.add(values); stdList.add(std); if (normalize) { // Normalize the data double c = factor / std; VectorUtils.multiply(values, c); cList.add(c); } } } final int p = attrsList.size(); int[] attrs = new int[p]; int[][] indices = new int[p][]; double[][] values = new double[p][]; for (int j = 0; j < p; j++) { attrs[j] = attrsList.get(j); indices[j] = indicesList.get(j); values[j] = valuesList.get(j); 
} double[] std = new double[stdList.size()]; for (int j = 0; j < std.length; j++) { std[j] = stdList.get(j); } double[] c = null; if (cList != null) { c = new double[cList.size()]; for (int j = 0; j < c.length; j++) { c[j] = cList.get(j); } } return new SparseDataset(attrs, indices, values, y, std, c); } /** * Returns the column-oriented format of dense dataset. This method automatically removes attributes with * close-to-zero variance. * * @param instances the instances. * @param normalize {@code true} if all the columns are normalized. * @return the column-oriented format of dense dataset. */ protected DenseDataset getDenseDataset(Instances instances, boolean normalize) { List attributes = instances.getAttributes(); final int p = instances.dimension(); final int n = instances.size(); // Convert to column oriented format List xList = new ArrayList<>(p); double[] y = new double[n]; for (int i = 0; i < n; i++) { y[i] = instances.get(i).getTarget(); } List attrsList = new ArrayList<>(p); List stdList = new ArrayList<>(p); List cList = null; if (normalize) { cList = new ArrayList<>(); } double factor = Math.sqrt(n); for (int j = 0; j < p; j++) { int attIndex = attributes.get(j).getIndex(); double[] x = new double[n]; for (int i = 0; i < n; i++) { x[i] = instances.get(i).getValue(attIndex); } double std = StatUtils.sd(x); if (std > MathUtils.EPSILON) { attrsList.add(attIndex); xList.add(x); stdList.add(std); if (normalize) { // Normalize the data double c = factor / std; VectorUtils.multiply(x, c); cList.add(c); } } } int[] attrs = new int[attrsList.size()]; double[][] x = new double[attrsList.size()][]; for (int j = 0; j < attrs.length; j++) { attrs[j] = attrsList.get(j); x[j] = xList.get(j); } double[] std = new double[stdList.size()]; for (int j = 0; j < std.length; j++) { std[j] = stdList.get(j); } double[] c = null; if (cList != null) { c = new double[cList.size()]; for (int j = 0; j < c.length; j++) { c[j] = cList.get(j); } } return new DenseDataset(attrs, x, y, 
std, c); } /** * Class for sparse dataset. * */ protected class SparseDataset { public int[] attrs; public int[][] indices; public double[][] values; public double[] y; public double[] stdList; public double[] cList; SparseDataset(int[] attrs, int[][] indices, double[][] values, double[] y, double[] stdList, double[] cList) { this.attrs = attrs; this.indices = indices; this.values = values; this.y = y; this.stdList = stdList; this.cList = cList; } } /** * Class for dense dataset. * */ protected class DenseDataset { public int[] attrs; public double[][] x; public double[] y; public double[] stdList; public double[] cList; DenseDataset(int[] attrs, double[][] x, double[] y, double[] stdList, double[] cList) { this.attrs = attrs; this.x = x; this.y = y; this.stdList = stdList; this.cList = cList; } } } ================================================ FILE: src/main/java/mltk/predictor/LinkFunction.java ================================================ package mltk.predictor; import mltk.util.MathUtils; /** * Class for link functions. * * @author Yin Lou * */ public enum LinkFunction { IDENTITY("identity"), LOGIT("logit"); /** * Parses a link function from a string. * * @param name the name of the function. * @return a parsed link function. */ public static LinkFunction get(String name) { for (LinkFunction link : LinkFunction.values()) { if (link.name.startsWith(name)) { return link; } } throw new IllegalArgumentException("Unknown link function: " + name); } String name; LinkFunction(String name) { this.name = name; } /** * Applies the inverse of this link function. * * @param x the argument. * @return the inverse of this link function. */ public double applyInverse(double x) { double r = 0; switch (this) { case IDENTITY: r = x; break; case LOGIT: r = MathUtils.sigmoid(x); break; default: break; } return r; } /** * Returns the string representation of this link function. 
*/ public String toString() { return name; } } ================================================ FILE: src/main/java/mltk/predictor/Predictor.java ================================================ package mltk.predictor; import java.io.BufferedReader; import java.io.PrintWriter; import mltk.core.Copyable; import mltk.core.Writable; /** * Interface for predictors. * * @author Yin Lou * */ public interface Predictor extends Writable, Copyable { /** * Reads in this predictor. This method is used in {@link mltk.predictor.io.PredictorReader}. * * @param in the reader. * @throws Exception */ public void read(BufferedReader in) throws Exception; /** * Writes this predictor. This method is used in {@link mltk.predictor.io.PredictorWriter}. * * @param out the writer. * @throws Exception */ public void write(PrintWriter out) throws Exception; } ================================================ FILE: src/main/java/mltk/predictor/ProbabilisticClassifier.java ================================================ package mltk.predictor; import mltk.core.Instance; /** * Interface for classifiers that predicts the class probabilities. * * @author Yin Lou * */ public interface ProbabilisticClassifier extends Classifier { /** * Returns the class probabilities. * * @param instance the instance to predict. * @return the class probabilities. */ public double[] predictProbabilities(Instance instance); } ================================================ FILE: src/main/java/mltk/predictor/Regressor.java ================================================ package mltk.predictor; import mltk.core.Instance; /** * Interface for regressors. * * @author Yin Lou * */ public interface Regressor extends Predictor { /** * Regresses an instance. * * @param instance the instance to regress. * @return a regressed value. 
*/ public double regress(Instance instance); } ================================================ FILE: src/main/java/mltk/predictor/evaluation/AUC.java ================================================ package mltk.predictor.evaluation; import java.util.Arrays; import java.util.Comparator; import mltk.core.Instances; import mltk.util.tuple.DoublePair; /** * Class for evaluating area under ROC curve. * * @author Yin Lou * */ public class AUC extends SimpleMetric { private class DoublePairComparator implements Comparator { @Override public int compare(DoublePair o1, DoublePair o2) { int cmp = Double.compare(o1.v1, o2.v1); if (cmp == 0) { cmp = Double.compare(o1.v2, o2.v2); } return cmp; } } /** * Constructor. */ public AUC() { super(true); } @Override public double eval(double[] preds, double[] targets) { DoublePair[] a = new DoublePair[preds.length]; for (int i = 0; i < preds.length; i++) { a[i] = new DoublePair(preds[i], targets[i]); } return eval(a); } @Override public double eval(double[] preds, Instances instances) { DoublePair[] a = new DoublePair[preds.length]; for (int i = 0; i < preds.length; i++) { a[i] = new DoublePair(preds[i], instances.get(i).getTarget()); } return eval(a); } protected double eval(DoublePair[] a) { Arrays.sort(a, new DoublePairComparator()); double[] fraction = new double[a.length]; for (int idx = 0; idx < fraction.length;) { int begin = idx; double pos = 0; for (; idx < fraction.length && a[idx].v1 == a[begin].v1; idx++) { pos += a[idx].v2; } double frac = pos / (idx - begin); for (int i = begin; i < idx; i++) { fraction[i] = frac; } } double tt = 0; double tf = 0; double ft = 0; double ff = 0; for (int i = 0; i < a.length; i++) { tf += a[i].v2; ff += 1 - a[i].v2; } double area = 0; double tpfPrev = 0; double fpfPrev = 0; for (int i = a.length - 1; i >= 0; i--) { tt += fraction[i]; tf -= fraction[i]; ft += 1 - fraction[i]; ff -= 1 - fraction[i]; double tpf = tt / (tt + tf); double fpf = 1.0 - ff / (ft + ff); area += 0.5 * (tpf + tpfPrev) 
* (fpf - fpfPrev); tpfPrev = tpf; fpfPrev = fpf; } return area; } } ================================================ FILE: src/main/java/mltk/predictor/evaluation/ConvergenceTester.java ================================================ package mltk.predictor.evaluation; import java.util.ArrayList; import java.util.List; /** * Class for testing convergence given a list of metric values. * * @author Yin Lou * */ public class ConvergenceTester { protected int minNumPoints; protected int n; protected double c; protected double bestSoFar; protected int bestIdx; protected Metric metric; protected List measureList; /** * Parses the convergence criteria string. * * @param cc the convergence criteria string. * @return a convergence tester. */ public static ConvergenceTester parse(String cc) { int minNumPoints = -1; int n = 0; double c = 1.0; if (cc != null && !cc.equals("")) { String[] strs = cc.split(":"); if (strs.length > 0) { minNumPoints = Integer.parseInt(strs[0]); } if (strs.length > 1) { n = Integer.parseInt(strs[1]); } if (strs.length > 2) { c = Double.parseDouble(strs[2]); } } return new ConvergenceTester(minNumPoints, n, c); } /** * Constructor. * * @param minNumPoints the minimum number of points to be considered convergence. * @param c a constant factor in [0, 1]. */ public ConvergenceTester(int minNumPoints, double c) { this(minNumPoints, 0, c, 1000); } /** * Constructor. * * @param minNumPoints the minimum number of points to be considered convergence. * @param n the n. */ public ConvergenceTester(int minNumPoints, int n) { this(minNumPoints, n, 1.0, 1000); } /** * Constructor. * * @param minNumPoints the minimum number of points to be considered convergence. * @param n the n. * @param c a constant factor in [0, 1]. */ public ConvergenceTester(int minNumPoints, int n, double c) { this(minNumPoints, n, c, 1000); } /** * Constructor. 
A list of metric values is viewed as converged if the list * has at least {@code minNumPoints} and the {@code getBestIndex() + n < size() * c}. * * @param minNumPoints the minimum number of points to be considered convergence. * @param n the n. * @param c a constant factor in [0, 1]. * @param capacity the initial capacity. */ public ConvergenceTester(int minNumPoints, int n, double c, int capacity) { if (n < 0) { throw new IllegalArgumentException("n has to be non-negative."); } if (!(c >= 0 && c <= 1)) { throw new IllegalArgumentException("c should to be in [0, 1]."); } this.minNumPoints = minNumPoints; this.n = n; this.c = c; measureList = new ArrayList<>(capacity); } /** * Returns the metric. * * @return the metric. */ public Metric getMetric() { return metric; } /** * Sets the metric. This method also resets internal status of this tester. * * @param metric the metric to set. */ public void setMetric(Metric metric) { this.metric = metric; measureList.clear(); bestSoFar = metric.worstValue(); bestIdx = -1; } /** * Adds a measure. * * @param measure the metric value to add. */ public void add(double measure) { measureList.add(measure); if (metric.isFirstBetter(measure, bestSoFar)) { bestSoFar = measure; bestIdx = measureList.size() - 1; } } /** * Returns the index of best metric value so far. * * @return the index of best metric value so far. */ public int getBestIndex() { return bestIdx; } /** * Returns the best measure value so far. * * @return the best measure value so far. */ public double getBestMetricValue() { return bestSoFar; } /** * Returns the number of points. * * @return the number of points. */ public int size() { return measureList.size(); } /** * Returns the list of metric values. * * @return the list of metric values. */ public List getMeasureList() { return measureList; } /** * Returns {@code true} if the series is converged. * * @return {@code true} if the series is converged. 
*/ public boolean isConverged() { return minNumPoints >= 0 && measureList.size() >= minNumPoints && bestIdx > 0 && bestIdx + n < measureList.size() * c; } } ================================================ FILE: src/main/java/mltk/predictor/evaluation/Error.java ================================================ package mltk.predictor.evaluation; import mltk.core.Instances; /** * Class for evaluating error rate. * * @author Yin Lou * */ public class Error extends SimpleMetric { /** * Constructor. */ public Error() { super(false); } @Override public double eval(double[] preds, double[] targets) { double error = 0; for (int i = 0; i < preds.length; i++) { // Handles both probability and predicted label double cls = preds[i] <= 0 ? 0 : 1; if (cls != targets[i]) { error++; } } return error / preds.length; } @Override public double eval(double[] preds, Instances instances) { double error = 0; for (int i = 0; i < preds.length; i++) { // Handles both probability and predicted label double cls = preds[i] <= 0 ? 0 : 1; if (cls != instances.get(i).getTarget()) { error++; } } return error / preds.length; } } ================================================ FILE: src/main/java/mltk/predictor/evaluation/Evaluator.java ================================================ package mltk.predictor.evaluation; import java.util.List; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.io.InstancesReader; import mltk.predictor.ProbabilisticClassifier; import mltk.predictor.Classifier; import mltk.predictor.Regressor; import mltk.predictor.io.PredictorReader; import mltk.util.OptimUtils; /** * Class for making evaluations. * * @author Yin Lou * */ public class Evaluator { /** * Returns the area under ROC curve. * * @param classifier a classifier that outputs probability. * @param instances the instances. * @return the area under ROC curve. 
*/ public static double evalAreaUnderROC(ProbabilisticClassifier classifier, Instances instances) { double[] probs = new double[instances.size()]; double[] targets = new double[instances.size()]; for (int i = 0; i < probs.length; i++) { Instance instance = instances.get(i); probs[i] = classifier.predictProbabilities(instance)[1]; targets[i] = instance.getTarget(); } return new AUC().eval(probs, targets); } /** * Returns the root mean squared error. * * @param preds the predictions. * @param targets the targets. * @return the root mean squared error. */ public static double evalRMSE(List preds, List targets) { double rmse = 0; for (int i = 0; i < preds.size(); i++) { double d = targets.get(i) - preds.get(i); rmse += d * d; } rmse = Math.sqrt(rmse / preds.size()); return rmse; } /** * Returns the root mean squared error. * * @param regressor the regressor. * @param instances the instances. * @return the root mean squared error. */ public static double evalRMSE(Regressor regressor, Instances instances) { double rmse = 0; for (int i = 0; i < instances.size(); i++) { Instance instance = instances.get(i); double target = instance.getTarget(); double pred = regressor.regress(instance); double d = target - pred; rmse += d * d; } rmse = Math.sqrt(rmse / instances.size()); return rmse; } /** * Returns the classification error. * * @param classifier the classifier. * @param instances the instances. * @return the classification error. */ public static double evalError(Classifier classifier, Instances instances) { double error = 0; for (int i = 0; i < instances.size(); i++) { Instance instance = instances.get(i); double target = instance.getTarget(); double pred = classifier.classify(instance); if (target != pred) { error++; } } error /= instances.size(); return error; } /** * Returns the logistic loss. * * @param regressor the regressor. * @param instances the instances. * @return the logistic loss. 
*/ public static double evalLogisticLoss(Regressor regressor, Instances instances) { double loss = 0; for (int i = 0; i < instances.size(); i++) { Instance instance = instances.get(i); double pred = regressor.regress(instance); loss += OptimUtils.computeLogisticLoss(pred, instance.getTarget()); } loss /= instances.size(); return loss; } /** * Returns the mean absolute error. * * @param regressor the regressor. * @param instances the instances. * @return the mean absolute error. */ public static double evalMAE(Regressor regressor, Instances instances) { double mae = 0; for (int i = 0; i < instances.size(); i++) { Instance instance = instances.get(i); double target = instance.getTarget(); double pred = regressor.regress(instance); double d = target - pred; mae += Math.abs(d); } mae /= instances.size(); return mae; } static class Options { @Argument(name = "-r", description = "attribute file path") String attPath = null; @Argument(name = "-d", description = "data set path", required = true) String dataPath = null; @Argument(name = "-m", description = "model path", required = true) String modelPath = null; @Argument(name = "-e", description = "AUC (a), Error (c), Logistic Loss (l), MAE(m), RMSE (r) (default: r)") String task = "r"; } /** * Evaluates a predictor. * *
	 * Usage: mltk.predictor.evaluation.Evaluator
	 * -d	data set path
	 * -m	model path
	 * [-r]	attribute file path
	 * [-e]	AUC (a), Error (c), Logistic Loss (l), MAE(m), RMSE (r) (default: r)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(Evaluator.class, opts); try { parser.parse(args); } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Instances instances = InstancesReader.read(opts.attPath, opts.dataPath); mltk.predictor.Predictor predictor = PredictorReader.read(opts.modelPath); switch (opts.task) { case "a": double auc = evalAreaUnderROC((ProbabilisticClassifier) predictor, instances); System.out.println("AUC: " + auc); break; case "c": double error = evalError((Classifier) predictor, instances); System.out.println("Error: " + error); break; case "l": double logisticLoss = evalLogisticLoss((Regressor) predictor, instances); System.out.println("Logistic Loss: " + logisticLoss); break; case "m": double mae = evalMAE((Regressor) predictor, instances); System.out.println("MAE: " + mae); break; case "r": double rmse = evalRMSE((Regressor) predictor, instances); System.out.println("RMSE: " + rmse); break; default: break; } } } ================================================ FILE: src/main/java/mltk/predictor/evaluation/LogLoss.java ================================================ package mltk.predictor.evaluation; import mltk.core.Instances; import mltk.util.OptimUtils; /** * Class for evaluating log loss (cross entropy). * * @author Yin Lou * */ public class LogLoss extends SimpleMetric { protected boolean isRawScore; /** * Constructor. */ public LogLoss() { this(false); } /** * Constructor. * * @param isRawScore {@code true} if raw score is expected as input. 
*/ public LogLoss(boolean isRawScore) { super(false); this.isRawScore = isRawScore; } @Override public double eval(double[] preds, double[] targets) { return OptimUtils.computeLogLoss(preds, targets, isRawScore); } @Override public double eval(double[] preds, Instances instances) { double logLoss = 0; for (int i = 0; i < preds.length; i++) { logLoss += OptimUtils.computeLogLoss(preds[i], instances.get(i).getTarget(), isRawScore); } logLoss /= preds.length; return logLoss; } /** * Returns {@code true} if raw score is expected as input. * * @return {@code true} if raw score is expected as input. */ public boolean isRawScore() { return isRawScore; } } ================================================ FILE: src/main/java/mltk/predictor/evaluation/LogisticLoss.java ================================================ package mltk.predictor.evaluation; import mltk.core.Instances; import mltk.util.OptimUtils; /** * Class for evaluating logistic loss. * * @author Yin Lou * */ public class LogisticLoss extends SimpleMetric { /** * Constructor. */ public LogisticLoss() { super(false); } @Override public double eval(double[] preds, double[] targets) { return OptimUtils.computeLogisticLoss(preds, targets); } @Override public double eval(double[] preds, Instances instances) { double logisticLoss = 0; for (int i = 0; i < preds.length; i++) { logisticLoss += OptimUtils.computeLogisticLoss(preds[i], instances.get(i).getTarget()); } logisticLoss /= preds.length; return logisticLoss; } } ================================================ FILE: src/main/java/mltk/predictor/evaluation/MAE.java ================================================ package mltk.predictor.evaluation; import mltk.core.Instances; /** * Class for evaluating mean absolute error (MAE). * * @author Yin Lou * */ public class MAE extends SimpleMetric { /** * Constructor. 
*/ public MAE() { super(false); } @Override public double eval(double[] preds, double[] targets) { double mae = 0; for (int i = 0; i < preds.length; i++) { mae += Math.abs(targets[i] - preds[i]); } mae /= preds.length; return mae; } @Override public double eval(double[] preds, Instances instances) { double mae = 0; for (int i = 0; i < preds.length; i++) { mae += Math.abs(instances.get(i).getTarget() - preds[i]); } mae /= preds.length; return mae; } } ================================================ FILE: src/main/java/mltk/predictor/evaluation/Metric.java ================================================ package mltk.predictor.evaluation; import java.util.List; import mltk.core.Instances; import mltk.util.MathUtils; /** * Class for evaluation metrics. * * @author Yin Lou * */ public abstract class Metric { private boolean isLargerBetter; /** * Constructor. * * @param isLargerBetter {@code true} if larger value is better. */ public Metric(boolean isLargerBetter) { this.isLargerBetter = isLargerBetter; } /** * Returns {@code true} if larger value is better for this metric. * * @return {@code true} if larger value is better for this metric. */ public boolean isLargerBetter() { return isLargerBetter; } /** * Returns {@code true} if the first value is better. * * @param a the 1st value. * @param b the 2nd value. * @return {@code true} if the first value is better. */ public boolean isFirstBetter(double a, double b) { return MathUtils.isFirstBetter(a, b, isLargerBetter); } /** * Returns the worst value of this metric. * * @return the worst value of this metric. */ public double worstValue() { if (isLargerBetter) { return Double.NEGATIVE_INFINITY; } else { return Double.POSITIVE_INFINITY; } } /** * Evaluates predictions on a dataset. * * @param preds the predictions. * @param instances the dataset. * @return the evaluation measure. */ public abstract double eval(double[] preds, Instances instances); /** * Returns the index of best metric value in a list. 
* * @param list the list of metric values. * @return the index of best metric value in a list. */ public int searchBestMetricValueIndex(List list) { double bestSoFar = worstValue(); int idx = -1; for (int i = 0; i < list.size(); i++) { if (isFirstBetter(list.get(i), bestSoFar)) { bestSoFar = list.get(i); idx = i; } } return idx; } } ================================================ FILE: src/main/java/mltk/predictor/evaluation/MetricFactory.java ================================================ package mltk.predictor.evaluation; import java.util.HashMap; import java.util.Map; /** * Factory class for creating metrics. * * @author Yin Lou * */ public class MetricFactory { private static Map map; static { map = new HashMap<>(); map.put("auc", new AUC()); map.put("error", new Error()); map.put("logisticloss", new LogisticLoss()); map.put("logloss", new LogLoss(false)); map.put("logloss_t", new LogLoss(true)); map.put("mae", new MAE()); map.put("rmse", new RMSE()); } /** * Returns the metric. * * @param str the metric string. * @return the metric. 
*/ public static Metric getMetric(String str) { String[] data = str.toLowerCase().split(":"); String name = data[0]; if (data.length == 1) { if (!map.containsKey(name)) { throw new IllegalArgumentException("Unrecognized metric name: " + name); } else { return map.get(name); } } else { if (name.equals("logloss")) { if (data[1].startsWith("t")) { return map.get("logloss_t"); } else { return map.get(name); } } else if (map.containsKey(name)) { return map.get(name); } } return null; } } ================================================ FILE: src/main/java/mltk/predictor/evaluation/Predictor.java ================================================ package mltk.predictor.evaluation; import java.io.IOException; import java.io.PrintWriter; import java.util.Arrays; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.io.InstancesReader; import mltk.predictor.Classifier; import mltk.predictor.Learner.Task; import mltk.predictor.ProbabilisticClassifier; import mltk.predictor.Regressor; import mltk.predictor.io.PredictorReader; import mltk.util.OptimUtils; /** * Class for making predictions. * * @author Yin Lou * */ public class Predictor { static class Options { @Argument(name = "-r", description = "attribute file path") String attPath = null; @Argument(name = "-d", description = "data set path", required = true) String dataPath = null; @Argument(name = "-m", description = "model path", required = true) String modelPath = null; @Argument(name = "-p", description = "prediction path") String predictionPath = null; @Argument(name = "-R", description = "residual path") String residualPath = null; @Argument(name = "-g", description = "task between classification (c) and regression (r) (default: r)") String task = "r"; @Argument(name = "-P", description = "output probablity (default: false)") boolean prob = false; } /** * Makes predictions on a dataset. * *
	 * Usage: mltk.predictor.evaluation.Predictor
	 * -d	data set path
	 * -m	model path
	 * [-r]	attribute file path
	 * [-p]	prediction path
	 * [-R]	residual path
	 * [-g]	task between classification (c) and regression (r) (default: r)
	 * [-P]	output probability (default: false)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(Predictor.class, opts); Task task = null; try { parser.parse(args); task = Task.get(opts.task); } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Instances instances = InstancesReader.read(opts.attPath, opts.dataPath); mltk.predictor.Predictor predictor = PredictorReader.read(opts.modelPath); switch (task) { case REGRESSION: Regressor regressor = (Regressor) predictor; double rmse = Evaluator.evalRMSE(regressor, instances); System.out.println("RMSE on Test: " + rmse); if (opts.predictionPath != null) { PrintWriter out = new PrintWriter(opts.predictionPath); for (Instance instance : instances) { double pred = regressor.regress(instance); out.println(pred); } out.flush(); out.close(); } if (opts.residualPath != null) { PrintWriter out = new PrintWriter(opts.residualPath); for (Instance instance : instances) { double pred = regressor.regress(instance); out.println(instance.getTarget() - pred); } out.flush(); out.close(); } break; case CLASSIFICATION: Classifier classifier = (Classifier) predictor; double error = Evaluator.evalError(classifier, instances); System.out.println("Error rate on Test: " + (error * 100) + " %"); if (opts.predictionPath != null) { if (opts.prob) { PrintWriter out = new PrintWriter(opts.predictionPath); ProbabilisticClassifier probClassifier = (ProbabilisticClassifier) predictor; for (Instance instance : instances) { double[] pred = probClassifier.predictProbabilities(instance); out.println(Arrays.toString(pred)); } out.flush(); out.close(); } else { PrintWriter out = new PrintWriter(opts.predictionPath); for (Instance instance : instances) { double pred = classifier.classify(instance); out.println((int) pred); } out.flush(); out.close(); } } if (opts.residualPath != null) { if (predictor instanceof Regressor) { PrintWriter 
out = new PrintWriter(opts.residualPath); Regressor regressingClassifier = (Regressor) predictor; for (Instance instance : instances) { double pred = regressingClassifier.regress(instance); int cls = (int) instance.getTarget(); out.println(OptimUtils.getPseudoResidual(pred, cls)); } out.flush(); out.close(); } else { System.out.println("Warning: Classifier does not support outputing pseudo residual."); } } break; default: break; } } /** * Makes predictions for a dataset. * * @param regressor the model. * @param instances the dataset. * @param path the output path. * @param residual {@code true} if residuals are the output. * @throws IOException */ public static void predict(Regressor regressor, Instances instances, String path, boolean residual) throws IOException { PrintWriter out = new PrintWriter(path); if (residual) { for (Instance instance : instances) { double pred = regressor.regress(instance); out.println(instance.getTarget() - pred); } } else { for (Instance instance : instances) { double pred = regressor.regress(instance); out.println(pred); } } out.flush(); out.close(); } /** * Makes predictions for a dataset. * * @param classifier the model. * @param instances the dataset. * @param path the output path. * @throws IOException */ public static void predict(Classifier classifier, Instances instances, String path) throws IOException { PrintWriter out = new PrintWriter(path); for (Instance instance : instances) { int pred = classifier.classify(instance); out.println(pred); } out.flush(); out.close(); } } ================================================ FILE: src/main/java/mltk/predictor/evaluation/RMSE.java ================================================ package mltk.predictor.evaluation; import mltk.core.Instances; /** * Class for evaluating root mean squared error (RMSE). * * @author Yin Lou * */ public class RMSE extends SimpleMetric { /** * Constructor. 
*/ public RMSE() { super(false); } @Override public double eval(double[] preds, double[] targets) { double rmse = 0; for (int i = 0; i < preds.length; i++) { double d = targets[i] - preds[i]; rmse += d * d; } rmse = Math.sqrt(rmse / preds.length); return rmse; } @Override public double eval(double[] preds, Instances instances) { double rmse = 0; for (int i = 0; i < preds.length; i++) { double d = instances.get(i).getTarget() - preds[i]; rmse += d * d; } rmse = Math.sqrt(rmse / preds.length); return rmse; } } ================================================ FILE: src/main/java/mltk/predictor/evaluation/SimpleMetric.java ================================================ package mltk.predictor.evaluation; /** * Class for simple metrics. * * @author Yin Lou * */ public abstract class SimpleMetric extends Metric { /** * Constructor. * * @param isLargerBetter {@code true} if larger value is better. */ public SimpleMetric(boolean isLargerBetter) { super(isLargerBetter); } /** * Evaluates predictions given targets. * * @param preds the predictions. * @param targets the targets. * @return the evaluation measure. */ public abstract double eval(double[] preds, double[] targets); } ================================================ FILE: src/main/java/mltk/predictor/evaluation/package-info.java ================================================ /** * Provides classes for evaluating predictors. */ package mltk.predictor.evaluation; ================================================ FILE: src/main/java/mltk/predictor/function/Array1D.java ================================================ package mltk.predictor.function; import java.io.BufferedReader; import java.io.PrintWriter; import java.util.Arrays; import mltk.core.Instance; import mltk.predictor.Regressor; import mltk.util.ArrayUtils; /** * Class for 1D lookup tables. * * @author Yin Lou * */ public class Array1D implements Regressor, UnivariateFunction { /** * Attribute index. 
Must be binned/nominal attribute; otherwise the behavior is not guaranteed.
     */
    protected int attIndex;

    /**
     * Predictions, indexed by the attribute value cast to {@code int}.
     */
    protected double[] predictions;

    /**
     * Prediction returned when the attribute value is missing (NaN).
     */
    protected double predictionOnMV;

    /**
     * Constructor. Leaves all fields at their Java defaults.
     */
    public Array1D() {

    }

    /**
     * Constructs a 1D lookup table whose prediction on missing value is 0.
     *
     * @param attIndex the attribute index. The attribute must be discretized or nominal.
     * @param predictions the prediction array (stored by reference, not copied).
     */
    public Array1D(int attIndex, double[] predictions) {
        this(attIndex, predictions, 0.0);
    }

    /**
     * Constructs a 1D lookup table.
     *
     * @param attIndex the attribute index. The attribute must be discretized or nominal.
     * @param predictions the prediction array (stored by reference, not copied).
     * @param predictionOnMV the prediction on missing value.
     */
    public Array1D(int attIndex, double[] predictions, double predictionOnMV) {
        this.attIndex = attIndex;
        this.predictions = predictions;
        this.predictionOnMV = predictionOnMV;
    }

    /**
     * Returns the attribute index.
     *
     * @return the attribute index.
     */
    public int getAttributeIndex() {
        return attIndex;
    }

    /**
     * Sets the attribute index.
     *
     * @param attIndex the new attribute index.
     */
    public void setAttributeIndex(int attIndex) {
        this.attIndex = attIndex;
    }

    /**
     * Returns the internal prediction array (the live array, not a copy).
     *
     * @return the internal prediction array.
     */
    public double[] getPredictions() {
        return predictions;
    }

    /**
     * Sets the internal prediction array (stored by reference, not copied).
     *
     * @param predictions the new prediction array.
     */
    public void setPredictions(double[] predictions) {
        this.predictions = predictions;
    }

    /**
     * Returns the prediction on missing value.
     *
     * @return the prediction on missing value.
     */
    public double getPredictionOnMV() {
        return predictionOnMV;
    }

    /**
     * Sets prediction on missing value.
     *
     * @param predictionOnMV the prediction on missing value.
*/
    public void setPredictionOnMV(double predictionOnMV) {
        this.predictionOnMV = predictionOnMV;
    }

    /**
     * Reads the fields in the order emitted by {@link #write(PrintWriter)}, starting at the
     * {@code AttIndex} line (the {@code [Predictor: ...]} header is not consumed here).
     *
     * @param in the reader positioned at the {@code AttIndex} line.
     */
    @Override
    public void read(BufferedReader in) throws Exception {
        attIndex = Integer.parseInt(in.readLine().split(": ")[1]);
        predictionOnMV = Double.parseDouble(in.readLine().split(": ")[1]);

        // Skip the "Predictions: <length>" line; the array itself is on the following line.
        in.readLine();
        predictions = ArrayUtils.parseDoubleArray(in.readLine());
    }

    /**
     * Writes a header line with the class name followed by the attribute index, the prediction on
     * missing value and the prediction array.
     *
     * @param out the writer.
     */
    @Override
    public void write(PrintWriter out) throws Exception {
        final String className = this.getClass().getCanonicalName();
        out.printf("[Predictor: %s]\n", className);
        out.println("AttIndex: " + attIndex);
        out.println("PredictionOnMV: " + predictionOnMV);
        out.println("Predictions: " + predictions.length);
        out.println(Arrays.toString(predictions));
    }

    /**
     * Looks up the prediction for the instance's value on this table's attribute.
     *
     * @param instance the instance.
     * @return the table entry at the value cast to {@code int}, or the prediction on missing value
     *         when the value is NaN.
     */
    @Override
    public double regress(Instance instance) {
        final double value = instance.getValue(attIndex);
        return Double.isNaN(value) ? predictionOnMV : predictions[(int) value];
    }

    /**
     * Adds this lookup table with another one.
     *
     * @param ary the other lookup table.
     * @return this lookup table.
*/ public Array1D add(Array1D ary) { if (attIndex != ary.attIndex) { throw new IllegalArgumentException("Cannot add arrays on different terms"); } for (int i = 0; i < predictions.length; i++) { predictions[i] += ary.predictions[i]; } predictionOnMV += ary.predictionOnMV; return this; } @Override public double evaluate(double x) { if (!Double.isNaN(x)) { return predictions[(int) x]; } else { return predictionOnMV; } } @Override public Array1D copy() { double[] predictionsCopy = Arrays.copyOf(predictions, predictions.length); return new Array1D(attIndex, predictionsCopy, predictionOnMV); } } ================================================ FILE: src/main/java/mltk/predictor/function/Array2D.java ================================================ package mltk.predictor.function; import java.io.BufferedReader; import java.io.PrintWriter; import java.util.Arrays; import mltk.core.Instance; import mltk.predictor.Regressor; import mltk.util.ArrayUtils; import mltk.util.tuple.IntPair; /** * Class for 2D lookup tables. * * @author Yin Lou * */ public class Array2D implements Regressor, BivariateFunction { /** * First attribute index. */ protected int attIndex1; /** * Second attribute index. */ protected int attIndex2; /** * Predictions. */ protected double[][] predictions; /** * Predictions on missing value for attribute 1. */ protected double[] predictionsOnMV1; /** * Predictions on missing value for attribute 2. */ protected double[] predictionsOnMV2; /** * Prediction when both attribute 1 and 2 are missing. */ protected double predictionOnMV12; /** * Constructor. */ public Array2D() { } /** * Constructs a 2D lookup table. * * @param attIndex1 the 1st attribute index. The attribute must be discretized or nominal. * @param attIndex2 the 2nd attribute index. The attribute must be discretized or nominal. * @param predictions the prediction matrix. 
*/ public Array2D(int attIndex1, int attIndex2, double[][] predictions) { this.attIndex1 = attIndex1; this.attIndex2 = attIndex2; this.predictions = predictions; } /** * Constructs a 2D lookup table. * * @param attIndex1 the 1st attribute index. The attribute must be discretized or nominal. * @param attIndex2 the 2nd attribute index. The attribute must be discretized or nominal. * @param predictions the prediction matrix. * @param predictionsOnMV1 the prediction array when the 1st attribute is missing. * @param predictionsOnMV2 the prediction array when the 2nd attribute is missing. * @param predictionOnMV12 the prediction when both attributes are missing. */ public Array2D(int attIndex1, int attIndex2, double[][] predictions, double[] predictionsOnMV1, double[] predictionsOnMV2, double predictionOnMV12) { this.attIndex1 = attIndex1; this.attIndex2 = attIndex2; this.predictions = predictions; this.predictionsOnMV1 = predictionsOnMV1; this.predictionsOnMV2 = predictionsOnMV2; this.predictionOnMV12 = predictionOnMV12; } /** * Returns the index of 1st attribute. * * @return the index of 1st attribute. */ public int getAttributeIndex1() { return attIndex1; } /** * Returns the index of 2nd attribute. * * @return the index of 2nd attribute. */ public int getAttributeIndex2() { return attIndex2; } /** * Returns the attribute indices pair. * * @return the attribute indices pair. */ public IntPair getAttributeIndices() { return new IntPair(attIndex1, attIndex2); } /** * Sets the attribute indices. * * @param attIndex1 the new 1st attribute index. * @param attIndex2 the new 2nd attribute index. */ public void setAttributeIndices(int attIndex1, int attIndex2) { this.attIndex1 = attIndex1; this.attIndex2 = attIndex2; } /** * Returns the internal prediction matrix. * * @return the internal prediction matrix. */ public double[][] getPredictions() { return predictions; } /** * Sets the internal prediction matrix. * * @param predictions the new prediction matrix. 
*/
    public void setPredictions(double[][] predictions) {
        this.predictions = predictions;
    }

    /**
     * Returns the internal prediction array when the 1st attribute is missing
     * (the live array, not a copy).
     *
     * @return the internal prediction array when the 1st attribute is missing.
     */
    public double[] getPredictionsOnMV1() {
        return predictionsOnMV1;
    }

    /**
     * Sets the internal prediction array when the 1st attribute is missing
     * (stored by reference, not copied).
     *
     * @param predictionsOnMV1 the new prediction array.
     */
    public void setPredictionsOnMV1(double[] predictionsOnMV1) {
        this.predictionsOnMV1 = predictionsOnMV1;
    }

    /**
     * Returns the internal prediction array when the 2nd attribute is missing
     * (the live array, not a copy).
     *
     * @return the internal prediction array when the 2nd attribute is missing.
     */
    public double[] getPredictionsOnMV2() {
        return predictionsOnMV2;
    }

    /**
     * Sets the internal prediction array when the 2nd attribute is missing
     * (stored by reference, not copied).
     *
     * @param predictionsOnMV2 the new prediction array.
     */
    public void setPredictionsOnMV2(double[] predictionsOnMV2) {
        this.predictionsOnMV2 = predictionsOnMV2;
    }

    /**
     * Returns the prediction when both attributes are missing.
     *
     * @return the prediction when both attributes are missing.
     */
    public double getPredictionOnMV12() {
        return predictionOnMV12;
    }

    /**
     * Sets the prediction when both attributes are missing.
     *
     * @param predictionOnMV12 the new prediction.
*/
    public void setPredictionOnMV12(double predictionOnMV12) {
        this.predictionOnMV12 = predictionOnMV12;
    }

    /**
     * Reads the fields in the order emitted by {@link #write(PrintWriter)}, starting at the
     * {@code AttIndex1} line (the {@code [Predictor: ...]} header is not consumed here).
     *
     * @param in the reader positioned at the {@code AttIndex1} line.
     */
    @Override
    public void read(BufferedReader in) throws Exception {
        attIndex1 = Integer.parseInt(in.readLine().split(": ")[1]);
        attIndex2 = Integer.parseInt(in.readLine().split(": ")[1]);

        // "Predictions: <rows>x<cols>" -- only the row count is needed, each row is parsed whole.
        final String[] dim = in.readLine().split(": ")[1].split("x");
        predictions = new double[Integer.parseInt(dim[0])][];
        for (int row = 0; row < predictions.length; row++) {
            predictions[row] = ArrayUtils.parseDoubleArray(in.readLine());
        }

        // Each missing-value table is preceded by a "<name>: <length>" line that is skipped.
        in.readLine();
        predictionsOnMV1 = ArrayUtils.parseDoubleArray(in.readLine());
        in.readLine();
        predictionsOnMV2 = ArrayUtils.parseDoubleArray(in.readLine());

        predictionOnMV12 = Double.parseDouble(in.readLine().split(": ")[1]);
    }

    /**
     * Writes a header line with the class name followed by the attribute indices, the prediction
     * matrix (one row per line) and the three missing-value predictions.
     *
     * @param out the writer.
     */
    @Override
    public void write(PrintWriter out) throws Exception {
        final String className = this.getClass().getCanonicalName();
        out.printf("[Predictor: %s]\n", className);
        out.println("AttIndex1: " + attIndex1);
        out.println("AttIndex2: " + attIndex2);
        out.println("Predictions: " + predictions.length + "x" + predictions[0].length);
        for (double[] row : predictions) {
            out.println(Arrays.toString(row));
        }
        out.println("PredictionsOnMV1: " + predictionsOnMV1.length);
        out.println(Arrays.toString(predictionsOnMV1));
        out.println("PredictionsOnMV2: " + predictionsOnMV2.length);
        out.println(Arrays.toString(predictionsOnMV2));
        out.println("PredictionOnMV12: " + predictionOnMV12);
    }

    /**
     * Evaluates this table on the instance's values for the two attributes.
     *
     * @param instance the instance.
     * @return the prediction.
     */
    @Override
    public double regress(Instance instance) {
        final double x = instance.getValue(attIndex1);
        final double y = instance.getValue(attIndex2);
        return evaluate(x, y);
    }

    /**
     * Adds this lookup table with another one.
     *
     * @param ary the other lookup table.
     * @return this lookup table.
*/ public Array2D add(Array2D ary) { if (attIndex1 != ary.attIndex1 || attIndex2 != ary.attIndex2) { throw new IllegalArgumentException("Cannot add arrays on differnt terms"); } for (int i = 0; i < predictions.length; i++) { predictionsOnMV2[i] += ary.predictionsOnMV2[i]; double[] preds1 = predictions[i]; double[] preds2 = ary.predictions[i]; for (int j = 0; j < preds1.length; j++) { predictionsOnMV1[j] += ary.predictionsOnMV1[j]; preds1[j] += preds2[j]; } } predictionOnMV12 += ary.predictionOnMV12; return this; } @Override public double evaluate(double x, double y) { if (!Double.isNaN(x) && !Double.isNaN(y)) { return predictions[(int) x][(int) y]; } else if (Double.isNaN(x) && !Double.isNaN(y)) { return predictionsOnMV1[(int) y]; } else if (!Double.isNaN(x) && Double.isNaN(y)) { return predictionsOnMV2[(int) x]; } else { return predictionOnMV12; } } @Override public Array2D copy() { double[][] predictionsCopy = new double[predictions.length][]; for (int i = 0; i < predictionsCopy.length; i++) { predictionsCopy[i] = Arrays.copyOf(predictions[i], predictions[i].length); } double[] predictionsOnMV1Copy = Arrays.copyOf(predictionsOnMV1, predictionsOnMV1.length); double[] predictionsOnMV2Copy = Arrays.copyOf(predictionsOnMV2, predictionsOnMV2.length); return new Array2D(attIndex1, attIndex2, predictionsCopy, predictionsOnMV1Copy, predictionsOnMV2Copy, predictionOnMV12); } } ================================================ FILE: src/main/java/mltk/predictor/function/BaggedLineCutter.java ================================================ package mltk.predictor.function; import java.util.ArrayList; import java.util.Collections; import java.util.List; import mltk.core.Attribute; import mltk.core.BinnedAttribute; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.core.Sampling; import mltk.predictor.BaggedEnsemble; import mltk.util.tuple.IntPair; import mltk.util.Element; import mltk.util.MathUtils; /** * Class for cutting 
lines with bagging.
 *
 * @author Yin Lou
 *
 */
public class BaggedLineCutter extends EnsembledLineCutter {

    /** One array of (instance index, multiplicity) pairs per bag; {@code null} until created. */
    private List<IntPair[]> samples;

    /**
     * Constructor.
     */
    public BaggedLineCutter() {
        this(false);
    }

    /**
     * Constructor.
     *
     * @param isClassification {@code true} if it is a classification problem.
     */
    public BaggedLineCutter(boolean isClassification) {
        attIndex = -1;
        this.isClassification = isClassification;
    }

    /**
     * Creates internal bootstrap samples. When {@code baggingIters} is not positive a single
     * sample containing every index once is created instead.
     *
     * @param n the size of the dataset to sample.
     * @param baggingIters the number of bagging iterations.
     */
    public void createBags(int n, int baggingIters) {
        // A negative initial capacity makes the ArrayList constructor throw, which would make the
        // "no bagging" branch below unreachable for negative values.
        samples = new ArrayList<>(Math.max(baggingIters, 1));
        if (baggingIters <= 0) {
            // No bagging
            IntPair[] indices = new IntPair[n];
            for (int i = 0; i < n; i++) {
                indices[i] = new IntPair(i, 1);
            }
            samples.add(indices);
        } else {
            for (int b = 0; b < baggingIters; b++) {
                samples.add(Sampling.createBootstrapSampleIndices(n));
            }
        }
    }

    @Override
    public BaggedEnsemble build(Instances instances) {
        return build(instances, attIndex, numIntervals);
    }

    /**
     * Builds one 1D function per bag on the given attribute. Bags are created on first use from
     * the configured number of bagging iterations. For classification the response is weighted
     * by the bag multiplicity only; otherwise it is also multiplied by the instance weight.
     *
     * @param instances the training set.
     * @param attribute the attribute to cut.
     * @param numIntervals the number of intervals.
     * @return the ensemble holding one function per bag.
     */
    @Override
    public BaggedEnsemble build(Instances instances, Attribute attribute, int numIntervals) {
        int attIndex = attribute.getIndex();

        if (samples == null) {
            createBags(instances.size(), baggingIters);
        }

        BaggedEnsemble ensemble = new BaggedEnsemble(samples.size());

        // Cache per-instance values once; every bag reads them by index.
        double[] targets = new double[instances.size()];
        double[] fvalues = new double[instances.size()];
        double[] weights = new double[instances.size()];
        for (int i = 0; i < instances.size(); i++) {
            Instance instance = instances.get(i);
            targets[i] = instance.getTarget();
            fvalues[i] = instance.getValue(attribute);
            weights[i] = instance.getWeight();
        }

        if (attribute.getType() == Attribute.Type.NUMERIC) {
            for (IntPair[] indices : samples) {
                double sumRespOnMV = 0.0;
                double sumWeightOnMV = 0.0;
                List<Element<double[]>> pairs = new ArrayList<>(instances.size());
                for (IntPair entry : indices) {
                    int index = entry.v1;
                    int w = entry.v2;
                    double weight = weights[index];
                    double value = fvalues[index];
                    double target = targets[index];
                    double response = isClassification ? target * w : target * weight * w;
                    if (!Double.isNaN(value)) {
                        pairs.add(new Element<double[]>(new double[] { response, weight * w }, value));
                    } else {
                        sumRespOnMV += response;
                        sumWeightOnMV += weight * w;
                    }
                }

                Collections.sort(pairs);
                List<double[]> histograms = new ArrayList<>();
                LineCutter.getHistograms(pairs, histograms);
                // The missing-value bucket is always appended last, keyed by NaN.
                histograms.add(new double[] { Double.NaN, sumRespOnMV, sumWeightOnMV });

                Function1D func = LineCutter.build(attIndex, histograms, numIntervals);
                ensemble.add(func);
            }
        } else {
            int size = 0;
            if (attribute.getType() == Attribute.Type.BINNED) {
                size = ((BinnedAttribute) attribute).getNumBins();
            } else {
                size = ((NominalAttribute) attribute).getCardinality();
            }
            for (IntPair[] indices : samples) {
                double sumRespOnMV = 0.0;
                double sumWeightOnMV = 0.0;
                double[][] histogram = new double[size][3];
                for (IntPair entry : indices) {
                    int index = entry.v1;
                    int w = entry.v2;
                    double weight = weights[index];
                    double value = fvalues[index];
                    double target = targets[index];
                    double response = isClassification ? target * w : target * weight * w;
                    if (!Double.isNaN(value)) {
                        int idx = (int) value;
                        histogram[idx][0] += response;
                        histogram[idx][1] += weight * w;
                    } else {
                        sumRespOnMV += response;
                        sumWeightOnMV += weight * w;
                    }
                }

                // Keep only the bins that received weight, tagged with their bin index.
                List<double[]> histograms = new ArrayList<>(histogram.length + 1);
                for (int i = 0; i < histogram.length; i++) {
                    if (!MathUtils.isZero(histogram[i][1])) {
                        double[] hist = histogram[i];
                        histograms.add(new double[] {i, hist[0], hist[1]});
                    }
                }
                histograms.add(new double[] { Double.NaN, sumRespOnMV, sumWeightOnMV });

                Function1D func = LineCutter.build(attIndex, histograms, numIntervals);
                ensemble.add(func);
            }
        }

        return ensemble;
    }

}

================================================
FILE: 
src/main/java/mltk/predictor/function/BivariateFunction.java ================================================ package mltk.predictor.function; /** * Interface for bivariate real functions. * * @author Yin Lou * */ public interface BivariateFunction { /** * Computes the value for the function. * * @param x the 1st argument. * @param y the 2nd argument. * @return the value for the function. */ public double evaluate(double x, double y); } ================================================ FILE: src/main/java/mltk/predictor/function/CHistogram.java ================================================ package mltk.predictor.function; /** * Class for cumulative histograms. * * @author Yin Lou * */ public class CHistogram { public double[] sum; public double[] count; public double sumOnMV; public double countOnMV; /** * Constructor. * * @param n the size of this cumulative histogram. */ public CHistogram(int n) { sum = new double[n]; count = new double[n]; sumOnMV = 0.0; countOnMV = 0.0; } /** * Returns the size of this cumulative histogram. * * @return the size of this cumulative histogram. */ public int size() { return sum.length; } /** * Returns {@code true} if missing values are present. * * @return {@code true} if missing values are present. */ public boolean hasMissingValue() { return countOnMV > 0; } } ================================================ FILE: src/main/java/mltk/predictor/function/CompressionUtils.java ================================================ package mltk.predictor.function; import mltk.core.Instance; import mltk.predictor.BaggedEnsemble; import mltk.predictor.BoostedEnsemble; import mltk.predictor.Predictor; import mltk.util.tuple.IntPair; /** * Class for utility functions for compressing ensembles of univariate/bivariate functions. * * @author Yin Lou * */ public class CompressionUtils { /** * Compresses a bagged ensemble of 1D functions to a single 1D function. * * @param attIndex the attribute index of this regressor. 
* @param baggedEnsemble the bagged ensemble. * @return a single compressed 1D function. */ public static Function1D compress(int attIndex, BaggedEnsemble baggedEnsemble) { Function1D function = Function1D.getConstantFunction(attIndex, 0); for (int i = 0; i < baggedEnsemble.size(); i++) { Predictor predictor = baggedEnsemble.get(i); Function1D func = null; if (predictor instanceof Function1D) { func = (Function1D) predictor; } else { throw new IllegalArgumentException(); } function.add(func); } function.divide(baggedEnsemble.size()); return function; } /** * Compresses a boosted ensemble to a single 1D function. * * @param attIndex the attribute of this regressor. * @param boostedEnsemble the boosted ensemble. * @return a single compressed 1D function. */ public static Function1D compress(int attIndex, BoostedEnsemble boostedEnsemble) { Function1D function = Function1D.getConstantFunction(attIndex, 0); for (int i = 0; i < boostedEnsemble.size(); i++) { Predictor predictor = boostedEnsemble.get(i); Function1D func = null; if (predictor instanceof Function1D) { func = (Function1D) predictor; } else if (predictor instanceof BaggedEnsemble) { func = compress(attIndex, (BaggedEnsemble) predictor); } else { throw new IllegalArgumentException(); } function.add(func); } return function; } /** * Converts a 1D function to 1D lookup table. * * @param n the number of bins. * @param function the 1D function. * @return a 1D lookup table. */ public static Array1D convert(int n, Function1D function) { double[] predictions = new double[n]; for (int i = 0; i < predictions.length; i++) { predictions[i] = function.evaluate(i); } return new Array1D(function.getAttributeIndex(), predictions, function.predictionOnMV); } /** * Compresses a bagged ensemble of 2D functions to a single 2D function. * * @param attIndex1 the 1st attribute index. * @param attIndex2 the 2nd attribute index. * @param baggedEnsemble the bagged ensemble. * @return a single compressed 2D functions. 
*/ public static Function2D compress(int attIndex1, int attIndex2, BaggedEnsemble baggedEnsemble) { // TODO check consistency problem when missing values are present. Function2D function = Function2D.getConstantFunction(attIndex1, attIndex2, 0); for (int i = 0; i < baggedEnsemble.size(); i++) { Predictor predictor = baggedEnsemble.get(i); Function2D func = null; if (predictor instanceof Function2D) { func = (Function2D) predictor; } else { throw new IllegalArgumentException(); } function.add(func); } function.divide(baggedEnsemble.size()); return function; } /** * Compresses a boosted ensemble to a single 2D function. * * @param attIndex1 the 1st attribute index. * @param attIndex2 the 2nd attribute index. * @param boostedEnsemble the boosted ensemble. * @return a single compressed 2D function. */ public static Function2D compress(int attIndex1, int attIndex2, BoostedEnsemble boostedEnsemble) { // TODO check consistency problem when missing values are present. Function2D function = Function2D.getConstantFunction(attIndex1, attIndex2, 0); for (int i = 0; i < boostedEnsemble.size(); i++) { Predictor predictor = boostedEnsemble.get(i); Function2D func = null; if (predictor instanceof Function2D) { func = (Function2D) predictor; } else if (predictor instanceof BaggedEnsemble) { func = compress(attIndex1, attIndex2, (BaggedEnsemble) predictor); } else { throw new IllegalArgumentException(); } function.add(func); } return function; } /** * Compresses and converts a boosted ensemble to a single 2D lookup table. * * @param attIndex1 the 1st attribute index. * @param attIndex2 the 2nd attribute index. * @param boostedEnsemble the boosted ensemble. * @return a 2D lookup table. 
*/
    public static Array2D compress(int attIndex1, int attIndex2, int n1, int n2,
            BoostedEnsemble boostedEnsemble) {
        double[][] predictions = new double[n1][n2];
        double[] predictionsOnMV1 = new double[n2];
        double[] predictionsOnMV2 = new double[n1];

        // The probe instance is re-used; the original code already relies on the instance
        // observing later writes to this array.
        double[] vector = new double[Math.max(attIndex1, attIndex2) + 1];
        Instance instance = new Instance(vector);

        // Predictions with the 1st attribute missing do not depend on the row, so they are
        // computed once up front instead of once per row (and are filled even when n1 is 0).
        for (int j = 0; j < n2; j++) {
            vector[attIndex1] = Double.NaN;
            vector[attIndex2] = j;
            predictionsOnMV1[j] = boostedEnsemble.regress(instance);
        }

        for (int i = 0; i < n1; i++) {
            vector[attIndex1] = i;
            vector[attIndex2] = Double.NaN;
            predictionsOnMV2[i] = boostedEnsemble.regress(instance);
            double[] preds = predictions[i];
            for (int j = 0; j < n2; j++) {
                vector[attIndex2] = j;
                vector[attIndex1] = i;
                preds[j] = boostedEnsemble.regress(instance);
            }
        }

        vector[attIndex1] = Double.NaN;
        vector[attIndex2] = Double.NaN;
        return new Array2D(attIndex1, attIndex2, predictions, predictionsOnMV1, predictionsOnMV2,
                boostedEnsemble.regress(instance));
    }

    /**
     * Converts a 2D function to 2D lookup table.
     *
     * @param n1 the number of bins for 1st attribute.
     * @param n2 the number of bins for 2nd attribute.
     * @param function the 2D function.
     * @return a 2D lookup table.
*/ public static Array2D convert(int n1, int n2, Function2D function) { double[][] predictions = new double[n1][n2]; double[] predictionsOnMV1 = new double[n2]; double[] predictionsOnMV2 = new double[n1]; for (int i = 0; i < n1; i++) { predictionsOnMV2[i] = function.evaluate(i, Double.NaN); double[] preds = predictions[i]; for (int j = 0; j < n2; j++) { predictionsOnMV1[j] = function.evaluate(Double.NaN, j); preds[j] = function.evaluate(i, j); } } IntPair attIndices = function.getAttributeIndices(); return new Array2D(attIndices.v1, attIndices.v2, predictions, predictionsOnMV1, predictionsOnMV2, function.predictionOnMV12); } } ================================================ FILE: src/main/java/mltk/predictor/function/CubicSpline.java ================================================ package mltk.predictor.function; import java.io.BufferedReader; import java.io.PrintWriter; import java.util.Arrays; import mltk.core.Instance; import mltk.predictor.Regressor; import mltk.util.ArrayUtils; /** * Class for cubic splines. Given knots k, the cubic spline uses the following basis: 1, x, x^2, x^3, (x - k[i])_+^3. * * @author Yin Lou * */ public class CubicSpline implements Regressor, UnivariateFunction { protected int attIndex; protected double intercept; protected double[] knots; protected double[] w; /** * Constructor. * * @param attIndex the attribute index. * @param intercept the intercept. * @param knots the knots. * @param w the coefficient vector. */ public CubicSpline(int attIndex, double intercept, double[] knots, double[] w) { this.attIndex = attIndex; this.intercept = intercept; this.knots = knots; this.w = w; } /** * Constructor. * * @param intercept the intercept. * @param knots the knots. * @param w the coefficient vector. */ public CubicSpline(double intercept, double[] knots, double[] w) { this(-1, intercept, knots, w); } /** * Constructor. * * @param knots the knots. * @param w the coefficient vector. 
*/ public CubicSpline(double[] knots, double[] w) { this(-1, 0, knots, w); } /** * Construct a cubic spline with specified knots and all coefficients set to zero. * * @param knots the knots. */ public CubicSpline(double[] knots) { this(-1, 0, knots, new double[knots.length + 3]); } /** * Constructor. */ public CubicSpline() { } @Override public void read(BufferedReader in) throws Exception { attIndex = Integer.parseInt(in.readLine().split(": ")[1]); intercept = Double.parseDouble(in.readLine().split(": ")[1]); in.readLine(); knots = ArrayUtils.parseDoubleArray(in.readLine()); in.readLine(); w = ArrayUtils.parseDoubleArray(in.readLine()); } @Override public void write(PrintWriter out) throws Exception { out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName()); out.println("AttIndex: " + attIndex); out.println("Intercept: " + intercept); out.println("Knots: " + knots.length); out.println(Arrays.toString(knots)); out.println("Coefficients: " + w.length); out.println(Arrays.toString(w)); } @Override public double evaluate(double x) { double pred = intercept + w[0] * x + w[1] * x * x + w[2] * x * x * x; for (int i = 0; i < knots.length; i++) { pred += h(x, knots[i]) * w[i + 3]; } return pred; } /** * Calculate the basis. * * @param x a real. * @param k a knot. * @return h(x, z), a basis in cubic spline. */ public static double h(double x, double k) { double t = x - k; if (t < 0) { return 0; } return t * t * t; } @Override public double regress(Instance instance) { return evaluate(instance.getValue(attIndex)); } /** * Returns the attribute index. * * @return the attribute index. */ public int getAttributeIndex() { return attIndex; } /** * Sets the attribute index. * * @param attIndex the new attribute index. */ public void setAttributeIndex(int attIndex) { this.attIndex = attIndex; } /** * Returns the coefficient vector. * * @return the coefficient vector. */ public double[] getCoefficients() { return w; } /** * Returns the knot vector. 
* * @return the knot vector. */ public double[] getKnots() { return knots; } /** * Returns the intercept. * * @return the intercept. */ public double getIntercept() { return intercept; } @Override public CubicSpline copy() { double[] newW = w.clone(); double[] newKnots = knots.clone(); return new CubicSpline(attIndex, intercept, newW, newKnots); } } ================================================ FILE: src/main/java/mltk/predictor/function/EnsembledLineCutter.java ================================================ package mltk.predictor.function; import mltk.core.Attribute; import mltk.core.Instances; import mltk.predictor.BaggedEnsemble; import mltk.predictor.Learner; public abstract class EnsembledLineCutter extends Learner { protected int attIndex; protected int numIntervals; protected int baggingIters; protected boolean isClassification; @Override public BaggedEnsemble build(Instances instances) { return build(instances, attIndex, numIntervals); } /** * Builds an 1D function ensemble. * * @param instances the training set. * @param attIndex the attribute index. * @param numIntervals the number of intervals. * @return an 1D function ensemble. */ public BaggedEnsemble build(Instances instances, int attIndex, int numIntervals) { Attribute attribute = instances.getAttributes().get(attIndex); return build(instances, attribute, numIntervals); } public abstract BaggedEnsemble build(Instances instances, Attribute attribute, int numIntervals); /** * Returns the index in the attribute list of the training set. * * @return the index in the attribute list of the training set. */ public int getAttributeIndex() { return attIndex; } /** * Sets the index in the attribute list of the training set. * * @param attIndex the attribute index. */ public void setAttributeIndex(int attIndex) { this.attIndex = attIndex; } /** * Returns the number of bagging iterations. * * @return the number of bagging iterations. 
*/ public int getBaggingIters() { return baggingIters; } /** * Sets the number of bagging iterations. * * @param baggingIters the number of bagging iterations. */ public void setBaggingIters(int baggingIters) { this.baggingIters = baggingIters; } /** * Returns {@code true} if it is a classification problem. * * @return {@code true} if it is a classification problem. */ public boolean isClassification() { return isClassification; } /** * Sets {@code true} if it is a classification problem. * * @param isClassification {@code true} if it is a classification problem. */ public void setClassification(boolean isClassification) { this.isClassification = isClassification; } /** * Returns the number of intervals. * * @return the number of intervals. */ public int getNumIntervals() { return numIntervals; } /** * Sets the number of intervals. * * @param numIntervals the number of intervals. */ public void setNumIntervals(int numIntervals) { this.numIntervals = numIntervals; } } ================================================ FILE: src/main/java/mltk/predictor/function/Function1D.java ================================================ package mltk.predictor.function; import java.io.BufferedReader; import java.io.PrintWriter; import java.util.Arrays; import mltk.core.Instance; import mltk.predictor.Regressor; import mltk.util.ArrayUtils; import mltk.util.MathUtils; import mltk.util.VectorUtils; /** * Class for 1D functions. * *

* <p>
 * This class represents a segmented 1D function. Segments are defined by the split array. For example,
 * [3, 5, +INF] defines three segments: (-INF, 3], (3, 5], (5, +INF). The last value in the split array is always
 * +INF. The prediction array holds the corresponding predictions for the segments defined by the splits.
 * </p>

* * @author Yin Lou * */ public class Function1D implements Regressor, UnivariateFunction { /** * Attribute index. */ protected int attIndex; /** * Last value is always Double.POSITIVE_INFINITY. e.g. [3, 5, +INF] defines three segments: (-INF, 3], (3, 5], (5, * +INF) */ protected double[] splits; /** * Corresponding predictions for segments defined in splits. */ protected double[] predictions; /** * Prediction on missing value. */ protected double predictionOnMV; /** * Returns a constant 1D function. * * @param attIndex the attribute index of this function. * @param prediction the constant. * @return a constant 1D function. */ public static Function1D getConstantFunction(int attIndex, double prediction) { Function1D func = new Function1D(attIndex, new double[] { Double.POSITIVE_INFINITY }, new double[] { prediction }); return func; } /** * Resets this function to 0. */ public void setZero() { splits = new double[] { Double.POSITIVE_INFINITY }; predictions = new double[] { 0 }; predictionOnMV = 0.0; } /** * Returns {@code true} if the function is 0. * * @return {@code true} if the function is 0. */ public boolean isZero() { return ArrayUtils.isConstant(predictions, 0, predictions.length, 0) && MathUtils.isZero(predictionOnMV); } /** * Returns {@code true} if the function is constant. * * @return {@code true} if the function is constant. */ public boolean isConstant() { return ArrayUtils.isConstant(predictions, 1, predictions.length, predictions[0]) && MathUtils.isZero(predictionOnMV - predictions[0]); } /** * Multiplies this function with a constant. * * @param c the constant. * @return this function. */ public Function1D multiply(double c) { VectorUtils.multiply(predictions, c); predictionOnMV *= c; return this; } /** * Divides this function with a constant. * * @param c the constant. * @return this function. */ public Function1D divide(double c) { VectorUtils.divide(predictions, c); predictionOnMV /= c; return this; } /** * Adds this function with a constant. 
* * @param c the constant. * @return this function. */ public Function1D add(double c) { VectorUtils.add(predictions, c); predictionOnMV += c; return this; } /** * Subtracts this function with a constant. * * @param c the constant. * @return this function. */ public Function1D subtract(double c) { VectorUtils.subtract(predictions, c); predictionOnMV -= c; return this; } /** * Constructor. */ public Function1D() { } /** * Constructor. * * @param attIndex the attribute index. * @param splits the splits. * @param predictions the predictions. */ public Function1D(int attIndex, double[] splits, double[] predictions) { this(attIndex, splits, predictions, 0.0); } /** * Constructor. * * @param attIndex the attribute index. * @param splits the splits. * @param predictions the predictions. * @param predictionOnMissing prediction on missing value; */ public Function1D(int attIndex, double[] splits, double[] predictions, double predictionOnMissing) { this.attIndex = attIndex; this.splits = splits; this.predictions = predictions; this.predictionOnMV = predictionOnMissing; } /** * Adds this function whit another function. * * @param func the other function. * @return this function. 
*/
    public Function1D add(Function1D func) {
        if (attIndex != func.attIndex) {
            throw new IllegalArgumentException("Cannot add functions on different terms");
        }

        // Gather the split points of func (its trailing +INF excluded) that this function does not have yet.
        final int numCandidates = func.splits.length - 1;
        double[] missing = new double[numCandidates];
        int numMissing = 0;
        for (int i = 0; i < numCandidates; i++) {
            double candidate = func.splits[i];
            if (Arrays.binarySearch(splits, candidate) < 0) {
                missing[numMissing++] = candidate;
            }
        }

        if (numMissing == 0) {
            // Same segmentation (or a refinement of func's): update the predictions in place.
            for (int i = 0; i < splits.length; i++) {
                predictions[i] += func.evaluate(splits[i]);
            }
        } else {
            // Merge both split sets, then re-evaluate both functions on every merged segment.
            double[] mergedSplits = Arrays.copyOf(splits, splits.length + numMissing);
            System.arraycopy(missing, 0, mergedSplits, splits.length, numMissing);
            Arrays.sort(mergedSplits);

            double[] mergedPredictions = new double[mergedSplits.length];
            for (int i = 0; i < mergedSplits.length; i++) {
                double at = mergedSplits[i];
                mergedPredictions[i] = this.evaluate(at) + func.evaluate(at);
            }
            splits = mergedSplits;
            predictions = mergedPredictions;
        }

        this.predictionOnMV += func.predictionOnMV;
        return this;
    }

    /**
     * Returns the attribute index.
     *
     * @return the attribute index.
     */
    public int getAttributeIndex() {
        return attIndex;
    }

    /**
     * Sets the attribute index.
     *
     * @param attIndex the new attribute index.
     */
    public void setAttributeIndex(int attIndex) {
        this.attIndex = attIndex;
    }

    /**
     * Returns the internal split array (not a copy).
     *
     * @return the internal split array.
     */
    public double[] getSplits() {
        return splits;
    }

    /**
     * Sets the split array.
     *
     * @param splits the new split array.
     */
    public void setSplits(double[] splits) {
        this.splits = splits;
    }

    /**
     * Returns the internal prediction array (not a copy).
     *
     * @return the internal prediction array.
     */
    public double[] getPredictions() {
        return predictions;
    }

    /**
     * Sets the prediction array.
     *
     * @param predictions the new prediction array.
*/
    public void setPredictions(double[] predictions) {
        this.predictions = predictions;
    }

    /**
     * Returns the prediction on missing value.
     *
     * @return the prediction on missing value.
     */
    public double getPredictionOnMV() {
        return predictionOnMV;
    }

    /**
     * Sets the prediction on missing value.
     *
     * @param predictionOnMV the prediction on missing value.
     */
    public void setPredictionOnMV(double predictionOnMV) {
        this.predictionOnMV = predictionOnMV;
    }

    @Override
    public void read(BufferedReader in) throws Exception {
        // Each scalar is stored as "<label>: <value>"; only the part after ": " is parsed.
        attIndex = Integer.parseInt(in.readLine().split(": ")[1]);
        predictionOnMV = Double.parseDouble(in.readLine().split(": ")[1]);
        // Skip the "Splits: <n>" header line, then parse the array line.
        in.readLine();
        splits = ArrayUtils.parseDoubleArray(in.readLine());
        // Skip the "Predictions: <n>" header line, then parse the array line.
        in.readLine();
        predictions = ArrayUtils.parseDoubleArray(in.readLine());
    }

    @Override
    public void write(PrintWriter out) throws Exception {
        out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName());
        out.println("AttIndex: " + attIndex);
        out.println("PredictinOnMV: " + predictionOnMV);
        out.println("Splits: " + splits.length);
        out.println(Arrays.toString(splits));
        out.println("Predictions: " + predictions.length);
        out.println(Arrays.toString(predictions));
    }

    @Override
    public double regress(Instance instance) {
        double x = instance.getValue(attIndex);
        return evaluate(x);
    }

    @Override
    public double evaluate(double x) {
        // NaN denotes a missing value and has its own prediction.
        return Double.isNaN(x) ? predictionOnMV : predictions[getSegmentIndex(x)];
    }

    @Override
    public Function1D copy() {
        return new Function1D(attIndex, splits.clone(), predictions.clone(), predictionOnMV);
    }

    /**
     * Returns the segment index given x.
     *
     * @param x the search key.
     * @return the segment index given x.
*/ protected int getSegmentIndex(double x) { // Assume x is not NaN return ArrayUtils.findInsertionPoint(splits, x); } } ================================================ FILE: src/main/java/mltk/predictor/function/Function2D.java ================================================ package mltk.predictor.function; import java.io.BufferedReader; import java.io.PrintWriter; import java.util.Arrays; import mltk.core.Instance; import mltk.predictor.Regressor; import mltk.util.ArrayUtils; import mltk.util.VectorUtils; import mltk.util.tuple.IntPair; /** * Class for 2D functions. * *

* <p>
 * This class represents a segmented 2D function. Segments are defined by one split array per attribute. For
 * example, [3, 5, +INF] defines three segments: (-INF, 3], (3, 5], (5, +INF). The last value in each split array
 * is always +INF. The prediction matrix holds the corresponding predictions for the segments defined by the
 * splits.
 * </p>

*
 * @author Yin Lou
 *
 */
public class Function2D implements Regressor, BivariateFunction {

    /**
     * First attribute index.
     */
    protected int attIndex1;

    /**
     * Second attribute index.
     */
    protected int attIndex2;

    /**
     * Split array for the first attribute. Last value is always Double.POSITIVE_INFINITY. e.g. [3, 5, +INF]
     * defines three segments: (-INF, 3], (3, 5], (5, +INF)
     */
    protected double[] splits1;

    /**
     * Split array for the second attribute.
     */
    protected double[] splits2;

    /**
     * Predictions.
     */
    protected double[][] predictions;

    /**
     * Predictions on missing value for attribute 1.
     */
    protected double[] predictionsOnMV1;

    /**
     * Predictions on missing value for attribute 2.
     */
    protected double[] predictionsOnMV2;

    /**
     * Prediction when both attribute 1 and 2 are missing.
     */
    protected double predictionOnMV12;

    /**
     * Creates a function without initializing any field.
     */
    public Function2D() {

    }

    /**
     * Creates a function whose predictions on missing values are all zero. The array used when the 1st attribute
     * is missing gets one entry per element of {@code splits2}; the array used when the 2nd attribute is missing
     * gets one entry per element of {@code splits1}.
     *
     * @param attIndex1 the 1st attribute index.
     * @param attIndex2 the 2nd attribute index.
     * @param splits1 the split array for the 1st attribute.
     * @param splits2 the split array for the 2nd attribute.
     * @param predictions the prediction matrix.
     */
    public Function2D(int attIndex1, int attIndex2, double[] splits1, double[] splits2, double[][] predictions) {
        this(attIndex1, attIndex2, splits1, splits2, predictions, new double[splits2.length],
                new double[splits1.length], 0.0);
    }

    /**
     * Creates a function from its parts. The arrays are stored as given, not copied.
     *
     * @param attIndex1 the 1st attribute index.
     * @param attIndex2 the 2nd attribute index.
     * @param splits1 the split array for the 1st attribute.
     * @param splits2 the split array for the 2nd attribute.
     * @param predictions the prediction matrix.
     * @param predictionsOnMV1 the prediction array when the 1st attribute is missing.
     * @param predictionsOnMV2 the prediction array when the 2nd attribute is missing.
     * @param predictionOnMV12 the prediction when both attributes are missing.
*/ public Function2D(int attIndex1, int attIndex2, double[] splits1, double[] splits2, double[][] predictions, double[] predictionsOnMV1, double[] predictionsOnMV2, double predictionOnMV12) { this.attIndex1 = attIndex1; this.attIndex2 = attIndex2; this.predictions = predictions; this.splits1 = splits1; this.splits2 = splits2; this.predictionsOnMV1 = predictionsOnMV1; this.predictionsOnMV2 = predictionsOnMV2; this.predictionOnMV12 = predictionOnMV12; } /** * Returns a constant 2D function. * * @param attIndex1 the 1st attribute index. * @param attIndex2 the 2nd attribute index. * @param prediction the constant. * @return a constant 2D function. */ public static Function2D getConstantFunction(int attIndex1, int attIndex2, double prediction) { Function2D func = new Function2D(attIndex1, attIndex2, new double[] { Double.POSITIVE_INFINITY }, new double[] { Double.POSITIVE_INFINITY }, new double[][] { { prediction } }); return func; } /** * Returns the index of 1st attribute. * * @return the index of 1st attribute. */ public int getAttributeIndex1() { return attIndex1; } /** * Returns the index of 2nd attribute. * * @return the index of 2nd attribute. */ public int getAttributeIndex2() { return attIndex2; } /** * Returns the attribute indices pair. * * @return the attribute indices pair. */ public IntPair getAttributeIndices() { return new IntPair(attIndex1, attIndex2); } /** * Sets the attribute indices. * * @param attIndex1 the new index for the 1st attribute. * @param attIndex2 the new index for the 2nd attribute. */ public void setAttributeIndices(int attIndex1, int attIndex2) { this.attIndex1 = attIndex1; this.attIndex2 = attIndex2; } /** * Returns the internal prediction matrix. * * @return the internal prediction matrix. */ public double[][] getPredictions() { return predictions; } /** * Sets the prediction matrix. * * @param predictions the new prediction matrix. 
*/
    public void setPredictions(double[][] predictions) {
        this.predictions = predictions;
    }

    /**
     * Returns the prediction array when the 1st attribute is missing.
     *
     * @return the prediction array when the 1st attribute is missing.
     */
    public double[] getPredictionsOnMV1() {
        return predictionsOnMV1;
    }

    /**
     * Sets the prediction array when the 1st attribute is missing.
     *
     * @param predictionsOnMV1 the new prediction array.
     */
    public void setPredictionsOnMV1(double[] predictionsOnMV1) {
        this.predictionsOnMV1 = predictionsOnMV1;
    }

    /**
     * Returns the prediction array when the 2nd attribute is missing.
     *
     * @return the prediction array when the 2nd attribute is missing.
     */
    public double[] getPredictionsOnMV2() {
        return predictionsOnMV2;
    }

    /**
     * Sets the prediction array when the 2nd attribute is missing.
     *
     * @param predictionsOnMV2 the new prediction array.
     */
    public void setPredictionsOnMV2(double[] predictionsOnMV2) {
        this.predictionsOnMV2 = predictionsOnMV2;
    }

    /**
     * Returns the prediction when both attributes are missing.
     *
     * @return the prediction when both attributes are missing.
     */
    public double getPredictionOnMV12() {
        return predictionOnMV12;
    }

    /**
     * Sets the prediction when both attributes are missing.
     *
     * @param predictionOnMV12 the new prediction.
     */
    public void setPredictionOnMV12(double predictionOnMV12) {
        this.predictionOnMV12 = predictionOnMV12;
    }

    /**
     * Multiplies this function by a constant, in place. The prediction matrix and all missing-value predictions
     * are scaled.
     *
     * @param c the constant.
     * @return this function.
     */
    public Function2D multiply(double c) {
        for (double[] row : predictions) {
            VectorUtils.multiply(row, c);
        }
        VectorUtils.multiply(predictionsOnMV1, c);
        VectorUtils.multiply(predictionsOnMV2, c);
        predictionOnMV12 *= c;
        return this;
    }

    /**
     * Divides this function by a constant, in place.
     *
     * @param c the constant.
     * @return this function.
*/ public Function2D divide(double c) { for (double[] preds : predictions) { VectorUtils.divide(preds, c); } VectorUtils.divide(predictionsOnMV1, c); VectorUtils.divide(predictionsOnMV2, c); predictionOnMV12 /= c; return this; } /** * Adds this function with a constant. * * @param c the constant. * @return this function. */ public Function2D add(double c) { for (double[] preds : predictions) { VectorUtils.add(preds, c); } VectorUtils.add(predictionsOnMV1, c); VectorUtils.add(predictionsOnMV2, c); predictionOnMV12 += c; return this; } /** * Subtracts this function with a constant. * * @param c the constant. * @return this function. */ public Function2D subtract(double c) { for (double[] preds : predictions) { VectorUtils.subtract(preds, c); } VectorUtils.subtract(predictionsOnMV1, c); VectorUtils.subtract(predictionsOnMV2, c); predictionOnMV12 -= c; return this; } /** * Adds this function with another one. * * @param func the other function. * @return this function. */ public Function2D add(Function2D func) { if (attIndex1 != func.attIndex1 || attIndex2 != func.attIndex2) { throw new IllegalArgumentException("Cannot add arrays on differnt terms"); } double[] s1 = splits1; int[] insertionPoints1 = new int[func.splits1.length - 1]; int newElements1 = 0; for (int i = 0; i < insertionPoints1.length; i++) { insertionPoints1[i] = Arrays.binarySearch(splits1, func.splits1[i]); if (insertionPoints1[i] < 0) { newElements1++; } } if (newElements1 > 0) { double[] newSplits1 = new double[splits1.length + newElements1]; System.arraycopy(splits1, 0, newSplits1, 0, splits1.length); int k = splits1.length; for (int i = 0; i < insertionPoints1.length; i++) { if (insertionPoints1[i] < 0) { newSplits1[k++] = func.splits1[i]; } } Arrays.sort(newSplits1); s1 = newSplits1; } double[] s2 = splits2; int[] insertionPoints2 = new int[func.splits2.length - 1]; int newElements2 = 0; for (int j = 0; j < insertionPoints2.length; j++) { insertionPoints2[j] = Arrays.binarySearch(splits2, 
func.splits2[j]); if (insertionPoints2[j] < 0) { newElements2++; } } if (newElements2 > 0) { double[] newSplits2 = new double[splits2.length + newElements2]; System.arraycopy(splits2, 0, newSplits2, 0, splits2.length); int k = splits2.length; for (int j = 0; j < insertionPoints2.length; j++) { if (insertionPoints2[j] < 0) { newSplits2[k++] = func.splits2[j]; } } Arrays.sort(newSplits2); s2 = newSplits2; } if (newElements1 == 0 && newElements2 == 0) { for (int i = 0; i < splits1.length; i++) { predictionsOnMV2[i] += func.evaluate(splits1[i], Double.NaN); double[] ps = predictions[i]; for (int j = 0; j < splits2.length; j++) { predictionsOnMV1[j] += func.evaluate(Double.NaN, splits2[j]); ps[j] += func.evaluate(splits1[i], splits2[j]); } } predictionOnMV12 += func.predictionOnMV12; } else { double[][] newPredictions = new double[s1.length][s2.length]; predictionsOnMV1 = new double[s2.length]; predictionsOnMV2 = new double[s1.length]; for (int i = 0; i < s1.length; i++) { predictionsOnMV2[i] = this.evaluate(s1[i], Double.NaN) + func.evaluate(s1[i], Double.NaN); double[] ps = newPredictions[i]; for (int j = 0; j < s2.length; j++) { predictionsOnMV1[j] = this.evaluate(Double.NaN, s2[j]) + func.evaluate(Double.NaN, s2[j]); ps[j] = this.evaluate(s1[i], s2[j]) + func.evaluate(s1[i], s2[j]); } } splits1 = s1; splits2 = s2; predictions = newPredictions; predictionOnMV12 = this.predictionOnMV12 + func.predictionOnMV12; } return this; } @Override public void read(BufferedReader in) throws Exception { String line = in.readLine(); String[] data = line.split(": "); attIndex1 = Integer.parseInt(data[1]); line = in.readLine(); data = line.split(": "); attIndex2 = Integer.parseInt(data[1]); in.readLine(); line = in.readLine(); splits1 = ArrayUtils.parseDoubleArray(line); in.readLine(); line = in.readLine(); splits2 = ArrayUtils.parseDoubleArray(line); String[] dim = in.readLine().split(": ")[1].split("x"); predictions = new double[Integer.parseInt(dim[0])][]; for (int i = 0; i < 
predictions.length; i++) { predictions[i] = ArrayUtils.parseDoubleArray(in.readLine()); } in.readLine(); predictionsOnMV1 = ArrayUtils.parseDoubleArray(in.readLine()); in.readLine(); predictionsOnMV2 = ArrayUtils.parseDoubleArray(in.readLine()); data = in.readLine().split(": "); predictionOnMV12 = Double.parseDouble(data[1]); } @Override public void write(PrintWriter out) throws Exception { out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName()); out.println("AttIndex1: " + attIndex1); out.println("AttIndex2: " + attIndex2); out.println("Splits1: " + splits1.length); out.println(Arrays.toString(splits1)); out.println("Splits2: " + splits2.length); out.println(Arrays.toString(splits2)); out.println("Predictions: " + predictions.length + "x" + predictions[0].length); for (int i = 0; i < predictions.length; i++) { out.println(Arrays.toString(predictions[i])); } out.println("PredictionsOnMV1: " + predictionsOnMV1.length); out.println(Arrays.toString(predictionsOnMV1)); out.println("PredictionsOnMV2: " + predictionsOnMV2.length); out.println(Arrays.toString(predictionsOnMV2)); out.println("PredictionOnMV12: " + predictionOnMV12); } @Override public double regress(Instance instance) { return evaluate(instance.getValue(attIndex1), instance.getValue(attIndex2)); } @Override public double evaluate(double x, double y) { if (!Double.isNaN(x) && !Double.isNaN(y)) { IntPair idx = getSegmentIndex(x, y); return predictions[idx.v1][idx.v2]; } else if (Double.isNaN(x) && !Double.isNaN(y)) { int idx = ArrayUtils.findInsertionPoint(splits2, y); return predictionsOnMV1[idx]; } else if (!Double.isNaN(x) && Double.isNaN(y)) { int idx = ArrayUtils.findInsertionPoint(splits1, x); return predictionsOnMV2[idx]; } else { return predictionOnMV12; } } @Override public Function2D copy() { double[] splits1Copy = Arrays.copyOf(splits1, splits1.length); double[] splits2Copy = Arrays.copyOf(splits2, splits2.length); double[][] predictionsCopy = new double[predictions.length][]; for (int 
i = 0; i < predictionsCopy.length; i++) { predictionsCopy[i] = Arrays.copyOf(predictions[i], predictions[i].length); } double[] predictionsOnMV1Copy = Arrays.copyOf(predictionsOnMV1, predictionsOnMV1.length); double[] predictionsOnMV2Copy = Arrays.copyOf(predictionsOnMV2, predictionsOnMV2.length); return new Function2D(attIndex1, attIndex2, splits1Copy, splits2Copy, predictionsCopy, predictionsOnMV1Copy, predictionsOnMV2Copy, predictionOnMV12); } /** * Returns the segment indices pair given (x1, x2). Assume x1 and x2 are not missing values. * * @param x1 the 1st search key. * @param x2 the 2nd search key. * @return segment indices pair at (x1, x2). */ protected IntPair getSegmentIndex(double x1, double x2) { int idx1 = ArrayUtils.findInsertionPoint(splits1, x1); int idx2 = ArrayUtils.findInsertionPoint(splits2, x2); return new IntPair(idx1, idx2); } } ================================================ FILE: src/main/java/mltk/predictor/function/Histogram2D.java ================================================ package mltk.predictor.function; import mltk.core.Instance; import mltk.core.Instances; import mltk.util.tuple.Pair; /** * Class for 2D histograms. * * @author Yin Lou * */ public class Histogram2D { public double[][] resp; public double[][] count; public double[] respOnMV1; public double[] countOnMV1; public double[] respOnMV2; public double[] countOnMV2; public double respOnMV12; public double countOnMV12; public static class Table { public double[][][] resp; public double[][][] count; public double[][] respOnMV1; public double[][] countOnMV1; public double[][] respOnMV2; public double[][] countOnMV2; public double respOnMV12; public double countOnMV12; public Table(int n, int m) { resp = new double[n][m][4]; count = new double[n][m][4]; respOnMV1 = new double[m][2]; countOnMV1 = new double[m][2]; respOnMV2 = new double[n][2]; countOnMV2 = new double[n][2]; respOnMV12 = 0.0; countOnMV12 = 0.0; } } /** * Computes 2D histogram given (f1, f2). 
* * @param instances the data set. * @param f1 the 1st feature. * @param f2 the 2nd feature. * @param hist2d the histogram to compute. */ public static void computeHistogram2D(Instances instances, int f1, int f2, Histogram2D hist2d) { for (Instance instance : instances) { double resp = instance.getTarget() * instance.getWeight(); double weight = instance.getWeight(); if (!instance.isMissing(f1) && !instance.isMissing(f2)) { int idx1 = (int) instance.getValue(f1); int idx2 = (int) instance.getValue(f2); hist2d.resp[idx1][idx2] += resp; hist2d.count[idx1][idx2] += weight; } else if (instance.isMissing(f1) && !instance.isMissing(f2)) { int idx2 = (int) instance.getValue(f2); hist2d.respOnMV1[idx2] += resp; hist2d.countOnMV1[idx2] += weight; } else if (!instance.isMissing(f1) && instance.isMissing(f2)) { int idx1 = (int) instance.getValue(f1); hist2d.respOnMV2[idx1] += resp; hist2d.countOnMV2[idx1] += weight; } else { hist2d.respOnMV12 += resp; hist2d.countOnMV12 += weight; } } } /** * Computes auxiliary data structure given 2D histogram and cumulative 1D histograms. * * @param hist2d the 2D histogram. * @param cHist1 the cumulative histogram for the 1st feature. * @param cHist2 the cumulative histogram for the 2nd feature. * @return table auxiliary data structure. 
*/ public static Table computeTable(Histogram2D hist2d, CHistogram cHist1, CHistogram cHist2) { Table table = new Table(hist2d.resp.length, hist2d.resp[0].length); double sum = 0; double count = 0; for (int j = 0; j < hist2d.resp[0].length; j++) { sum += hist2d.resp[0][j]; table.resp[0][j][0] = sum; count += hist2d.count[0][j]; table.count[0][j][0] = count; fillTable(table, 0, j, cHist1, cHist2); } for (int i = 1; i < hist2d.resp.length; i++) { sum = count = 0; for (int j = 0; j < hist2d.resp[i].length; j++) { sum += hist2d.resp[i][j]; table.resp[i][j][0] = table.resp[i - 1][j][0] + sum; count += hist2d.count[i][j]; table.count[i][j][0] = table.count[i - 1][j][0] + count; fillTable(table, i, j, cHist1, cHist2); } } double respOnMV1 = 0; double countOnMV1 = 0; for (int j = 0; j < hist2d.respOnMV1.length; j++) { respOnMV1 += hist2d.respOnMV1[j]; countOnMV1 += hist2d.countOnMV1[j]; table.respOnMV1[j][0] = respOnMV1; table.respOnMV1[j][1] = cHist1.sumOnMV - respOnMV1; table.countOnMV1[j][0] = countOnMV1; table.countOnMV1[j][1] = cHist1.countOnMV - countOnMV1; } double respOnMV2 = 0; double countOnMV2 = 0; for (int i = 0; i < hist2d.respOnMV2.length; i++) { respOnMV2 += hist2d.respOnMV2[i]; countOnMV2 += hist2d.countOnMV2[i]; table.respOnMV2[i][0] = respOnMV2; table.respOnMV2[i][1] = cHist2.sumOnMV - respOnMV2; table.countOnMV2[i][0] = countOnMV2; table.countOnMV2[i][1] = cHist2.countOnMV - countOnMV2; } table.respOnMV12 = hist2d.respOnMV12; table.countOnMV12 = hist2d.countOnMV12; return table; } protected static void fillTable(Table table, int i, int j, CHistogram cHist1, CHistogram cHist2) { double[] count = table.count[i][j]; double[] resp = table.resp[i][j]; resp[1] = cHist1.sum[i] - resp[0]; resp[2] = cHist2.sum[j] - resp[0]; resp[3] = cHist1.sum[cHist1.size() - 1] - cHist1.sum[i] - resp[2]; count[1] = cHist1.count[i] - count[0]; count[2] = cHist2.count[j] - count[0]; count[3] = cHist1.count[cHist1.size() - 1] - cHist1.count[i] - count[2]; } /** * Constructor. 
* * @param n the size of the 1st dimension. * @param m the size of the 2nd dimension. */ public Histogram2D(int n, int m) { resp = new double[n][m]; count = new double[n][m]; respOnMV1 = new double[m]; countOnMV1 = new double[m]; respOnMV2 = new double[n]; countOnMV2 = new double[n]; respOnMV12 = 0.0; countOnMV12 = 0.0; } /** * Computes the cumulative histograms on the margin. * * @return the cumulative histograms. */ public Pair computeCHistogram() { CHistogram cHist1 = new CHistogram(resp.length); CHistogram cHist2 = new CHistogram(resp[0].length); for (int i = 0; i < resp.length; i++) { double[] r = resp[i]; double[] c = count[i]; for (int j = 0; j < r.length; j++) { cHist1.sum[i] += r[j]; cHist1.count[i] += c[j]; cHist2.sum[j] += r[j]; cHist2.count[j] += c[j]; } } for (int i = 1; i < cHist1.size(); i++) { cHist1.sum[i] += cHist1.sum[i - 1]; cHist1.count[i] += cHist1.count[i - 1]; } for (int j = 1; j < cHist2.size(); j++) { cHist2.sum[j] += cHist2.sum[j - 1]; cHist2.count[j] += cHist2.count[j - 1]; } for (int j = 0; j < respOnMV1.length; j++) { cHist1.sumOnMV += respOnMV1[j]; cHist1.countOnMV += countOnMV1[j]; } cHist1.sumOnMV += respOnMV12; cHist1.countOnMV += countOnMV12; for (int i = 0; i < respOnMV2.length; i++) { cHist2.sumOnMV += respOnMV2[i]; cHist2.countOnMV += countOnMV2[i]; } cHist2.sumOnMV += respOnMV12; cHist2.countOnMV += countOnMV12; return new Pair(cHist1, cHist2); } } ================================================ FILE: src/main/java/mltk/predictor/function/LineCutter.java ================================================ package mltk.predictor.function; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.PriorityQueue; import mltk.core.Attribute; import mltk.core.BinnedAttribute; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.predictor.Learner; import mltk.util.Random; import mltk.util.Element; import mltk.util.MathUtils; import 
mltk.util.OptimUtils; /** * Class for cutting lines. * * @author Yin Lou * */ public class LineCutter extends Learner { static class Interval implements Comparable { boolean finalized; int start; int end; // INF: Can split but split point not computed // NaN: Declared as leaf // Other: Split point double split; double sum; double weight; double value; // mean * sum, or sum * sum /weight; negative gain double gain; Interval left; Interval right; Interval() { split = Double.POSITIVE_INFINITY; } Interval(int start, int end, double sum, double weight) { this.start = start; this.end = end; this.split = Double.POSITIVE_INFINITY; this.sum = sum; this.weight = weight; } @Override public int compareTo(Interval o) { if (this.value < o.value) { return -1; } else if (this.value > o.value) { return 1; } else { return 0; } } double getPrediction() { return MathUtils.divide(sum, weight, 0.0); } boolean isFinalized() { return finalized; } boolean isInteriorNode() { return split < Double.POSITIVE_INFINITY; } boolean isLeaf() { return Double.isNaN(split); } } private int attIndex; private int numIntervals; private boolean isClassification; /** * Constructor. */ public LineCutter() { this(false); } /** * Constructor. * * @param isClassification {@code true} if it is a classification problem. */ public LineCutter(boolean isClassification) { attIndex = -1; this.isClassification = isClassification; } @Override public Function1D build(Instances instances) { return build(instances, attIndex, numIntervals); } /** * Builds a 1D function. * * @param instances the training set. * @param attribute the attribute. * @param numIntervals the number of intervals. * @return a 1D function. 
*/ public Function1D build(Instances instances, Attribute attribute, int numIntervals) { int attIndex = attribute.getIndex(); double sumRespOnMV = 0.0; double sumWeightOnMV = 0.0; List histograms; if (attribute.getType() == Attribute.Type.NUMERIC) { // weight: attribute value // [feature value, sum, weight] List> pairs = new ArrayList<>(instances.size()); for (Instance instance : instances) { double weight = instance.getWeight(); double value = instance.getValue(attIndex); double target = instance.getTarget(); if (!Double.isNaN(value)) { if (isClassification) { pairs.add(new Element<>(new double[] { target, weight }, value)); } else { pairs.add(new Element<>(new double[] { target * weight, weight }, value)); } } else { if (isClassification) { sumRespOnMV += target; } else { sumRespOnMV += target * weight; } sumWeightOnMV += weight; } } Collections.sort(pairs); histograms = new ArrayList<>(pairs.size() + 1); getHistograms(pairs, histograms); histograms.add(new double[] { Double.NaN, sumRespOnMV, sumWeightOnMV }); } else { int size = 0; if (attribute.getType() == Attribute.Type.BINNED) { size = ((BinnedAttribute) attribute).getNumBins(); } else { size = ((NominalAttribute) attribute).getCardinality(); } double[][] histogram = new double[size][2]; for (Instance instance : instances) { double weight = instance.getWeight(); double value = instance.getValue(attIndex); double target = instance.getTarget(); if (!Double.isNaN(value)) { int idx = (int) value; if (isClassification) { histogram[idx][0] += target; } else { histogram[idx][0] += target * weight; } histogram[idx][1] += weight; } else { if (isClassification) { sumRespOnMV += target; } else { sumRespOnMV += target * weight; } sumWeightOnMV += weight; } } histograms = new ArrayList<>(histogram.length + 1); for (int i = 0; i < histogram.length; i++) { if (!MathUtils.isZero(histogram[i][1])) { double[] hist = histogram[i]; histograms.add(new double[] {i, hist[0], hist[1]}); } } histograms.add(new double[] { Double.NaN, 
sumRespOnMV, sumWeightOnMV }); } return build(attIndex, histograms, numIntervals); } /** * Builds a 1D function. * * @param instances the training set. * @param attIndex the index in the attribute list of the training set. * @param numIntervals the number of intervals. * @return a 1D function. */ public Function1D build(Instances instances, int attIndex, int numIntervals) { Attribute attribute = instances.getAttributes().get(attIndex); return build(instances, attribute, numIntervals); } /** * Returns the index in the attribute list of the training set. * * @return the index in the attribute list of the training set. */ public int getAttributeIndex() { return attIndex; } /** * Sets the index in the attribute list of the training set. * * @param attIndex the attribute index. */ public void setAttributeIndex(int attIndex) { this.attIndex = attIndex; } /** * Returns {@code true} if it is a classification problem. * * @return {@code true} if it is a classification problem. */ public boolean isClassification() { return isClassification; } /** * Sets {@code true} if it is a classification problem. * * @param isClassification {@code true} if it is a classification problem. */ public void setClassification(boolean isClassification) { this.isClassification = isClassification; } /** * Returns the number of intervals. * * @return the number of intervals. */ public int getNumIntervals() { return numIntervals; } /** * Sets the number of intervals. * * @param numIntervals the number of intervals. 
*/ public void setNumIntervals(int numIntervals) { this.numIntervals = numIntervals; } protected static void getHistograms(List> pairs, List histograms) { if (pairs.size() == 0) { return; } // Element(new double[] {sum, weight}, value) double[] hist = pairs.get(0).element; double lastValue = pairs.get(0).weight; double sum = hist[0]; double weight = hist[1]; for (int i = 1; i < pairs.size(); i++) { Element element = pairs.get(i); hist = element.element; double value = element.weight; double s = hist[0]; double w = hist[1]; if (value != lastValue) { histograms.add(new double[] { lastValue, sum, weight }); lastValue = value; sum = s; weight = w; } else { sum += s; weight += w; } } histograms.add(new double[] { lastValue, sum, weight }); } protected static Function1D build(int attIndex, List histograms, int numIntervals) { Function1D func = new Function1D(); func.attIndex = attIndex; // [feature value, sum, weight] double[] histOnMV = histograms.get(histograms.size() - 1); func.predictionOnMV = MathUtils.divide(histOnMV[1], histOnMV[2], 0.0); // 1. Check basic leaf conditions if (histograms.size() <= 2) { func.splits = new double[] { Double.POSITIVE_INFINITY }; double prediction = 0.0; if (histograms.size() == 2) { double[] hist = histograms.get(0); prediction = MathUtils.divide(hist[1], hist[2], 0.0); } func.predictions = new double[] { prediction }; return func; } // 2. 
Cut the line // 2.1 First cut double[] stats = sumUp(histograms, 0, histograms.size() - 1); Interval root = new Interval(0, histograms.size() - 1, stats[0], stats[1]); split(histograms, root); if (numIntervals == 2) { func.splits = new double[] { root.split, Double.POSITIVE_INFINITY }; func.predictions = new double[] { root.left.getPrediction(), root.right.getPrediction() }; } else if (numIntervals > 2) { PriorityQueue q = new PriorityQueue<>(); if (!root.isLeaf()) { q.add(root); } int numSplits = 0; while (!q.isEmpty()) { Interval parent = q.remove(); parent.finalized = true; split(histograms, parent.left); split(histograms, parent.right); if (!parent.left.isLeaf()) { q.add(parent.left); } if (!parent.right.isLeaf()) { q.add(parent.right); } numSplits++; if (numSplits >= numIntervals - 1) { break; } } List splits = new ArrayList<>(numIntervals - 1); List predictions = new ArrayList<>(numIntervals); inorder(root, splits, predictions); func.splits = new double[predictions.size()]; func.predictions = new double[predictions.size()]; for (int i = 0; i < func.predictions.length; i++) { func.predictions[i] = predictions.get(i); } for (int i = 0; i < func.splits.length - 1; i++) { func.splits[i] = splits.get(i); } func.splits[func.splits.length - 1] = Double.POSITIVE_INFINITY; } return func; } protected static double[] sumUp(List histograms, int start, int end) { double sum = 0; double weight = 0; for (int i = start; i < end; i++) { double[] hist = histograms.get(i); sum += hist[1]; weight += hist[2]; } return new double[] { sum, weight }; } protected static void split(List histograms, Interval parent) { split(histograms, parent, 5); } protected static void split(List histograms, Interval parent, double limit) { // Test if we need to split if (parent.weight <= limit || parent.end - parent.start <= 1) { parent.split = Double.NaN; // Declared as leaf } else { parent.left = new Interval(); parent.right = new Interval(); int start = parent.left.start = parent.start; int end = 
parent.right.end = parent.end; final double sum = parent.sum; final double totalWeight = parent.weight; double sum1 = histograms.get(start)[1]; double sum2 = sum - sum1; double weight1 = histograms.get(start)[2]; double weight2 = totalWeight - weight1; double bestEval = -(OptimUtils.getGain(sum1, weight1) + OptimUtils.getGain(sum2, weight2)); List splits = new ArrayList<>(); splits.add(new double[] { (histograms.get(start)[0] + histograms.get(start + 1)[0]) / 2, start, sum1, weight1, sum2, weight2 }); for (int i = start + 1; i < end - 1; i++) { double[] hist = histograms.get(i); final double s = hist[1]; final double w = hist[2]; sum1 += s; sum2 -= s; weight1 += w; weight2 -= w; double eval1 = OptimUtils.getGain(sum1, weight1); double eval2 = OptimUtils.getGain(sum2, weight2); double eval = -(eval1 + eval2); if (eval <= bestEval) { double split = (histograms.get(i)[0] + histograms.get(i + 1)[0]) / 2; if (eval < bestEval) { bestEval = eval; splits.clear(); } splits.add(new double[] { split, i, sum1, weight1, sum2, weight2 }); } } Random rand = Random.getInstance(); double[] split = splits.get(rand.nextInt(splits.size())); parent.split = split[0]; parent.left.end = (int) split[1] + 1; parent.right.start = (int) split[1] + 1; parent.left.sum = split[2]; parent.left.weight = split[3]; parent.right.sum = split[4]; parent.right.weight = split[5]; parent.gain = OptimUtils.getGain(parent.left.sum, parent.left.weight) + OptimUtils.getGain(parent.right.sum, parent.right.weight); parent.value = -parent.gain + (sum / totalWeight * sum); } } protected static void inorder(Interval parent, List splits, List predictions) { if (parent.isFinalized()) { inorder(parent.left, splits, predictions); splits.add(parent.split); inorder(parent.right, splits, predictions); } else { predictions.add(parent.getPrediction()); } } } ================================================ FILE: src/main/java/mltk/predictor/function/LinearFunction.java ================================================ 
package mltk.predictor.function;

import java.io.BufferedReader;
import java.io.PrintWriter;

import mltk.core.Instance;
import mltk.predictor.Regressor;

/**
 * A univariate linear function without intercept: {@code f(x) = beta * x}.
 *
 * <p>
 * When used as a regressor, the input is the value of the instance at the configured
 * attribute index.
 * </p>
 *
 * @author Yin Lou
 *
 */
public class LinearFunction implements Regressor, UnivariateFunction {

    /**
     * Index of the attribute whose value is fed to the function.
     */
    protected int attIndex;

    /**
     * Slope of the line.
     */
    protected double beta;

    /**
     * Creates a linear function with fields left at their default values.
     */
    public LinearFunction() {

    }

    /**
     * Creates a linear function with the given slope and an attribute index of {@code -1}.
     *
     * @param beta the slope.
     */
    public LinearFunction(double beta) {
        this(-1, beta);
    }

    /**
     * Creates a linear function with the given attribute index and slope.
     *
     * @param attIndex the attribute index.
     * @param beta the slope.
     */
    public LinearFunction(int attIndex, double beta) {
        this.attIndex = attIndex;
        this.beta = beta;
    }

    /**
     * Reads the attribute index and the slope from two consecutive lines, each taking the
     * text after the first {@code ": "} separator.
     *
     * @param in the reader positioned at the attribute index line.
     */
    @Override
    public void read(BufferedReader in) throws Exception {
        String attIndexLine = in.readLine();
        attIndex = Integer.parseInt(attIndexLine.split(": ")[1]);
        String betaLine = in.readLine();
        beta = Double.parseDouble(betaLine.split(": ")[1]);
    }

    /**
     * Writes a predictor header line followed by the attribute index and the slope.
     *
     * @param out the writer.
     */
    @Override
    public void write(PrintWriter out) throws Exception {
        String className = this.getClass().getCanonicalName();
        out.printf("[Predictor: %s]\n", className);
        out.println("AttIndex: " + attIndex);
        out.println("Beta: " + beta);
    }

    /**
     * Returns the slope times {@code x}.
     */
    @Override
    public double evaluate(double x) {
        return beta * x;
    }

    /**
     * Evaluates the function on the instance's value at the configured attribute index.
     */
    @Override
    public double regress(Instance instance) {
        double x = instance.getValue(attIndex);
        return evaluate(x);
    }

    /**
     * Returns the attribute index.
     *
     * @return the attribute index.
     */
    public int getAttributeIndex() {
        return attIndex;
    }

    /**
     * Returns the slope.
     *
     * @return the slope.
     */
    public double getSlope() {
        return beta;
    }

    /**
     * Sets the slope.
     *
     * @param beta the new slope.
     */
    public void setSlope(double beta) {
        this.beta = beta;
    }

    /**
     * Returns a new linear function with the same attribute index and slope.
     */
    @Override
    public LinearFunction copy() {
        LinearFunction duplicate = new LinearFunction(attIndex, beta);
        return duplicate;
    }

}
================================================ FILE: src/main/java/mltk/predictor/function/SquareCutter.java ================================================
package mltk.predictor.function;

import java.util.List;

import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import
mltk.predictor.Learner; import mltk.util.MathUtils; import mltk.util.OptimUtils; import mltk.util.tuple.Pair; /** * Class for cutting squares. * * @author Yin Lou * */ public class SquareCutter extends Learner { private int attIndex1; private int attIndex2; private boolean lineSearch; /** * Constructor. */ public SquareCutter() { } /** * Constructor. * * @param lineSearch {@code true} if line search is performed in the end. */ public SquareCutter(boolean lineSearch) { this.lineSearch = lineSearch; } /** * Sets the attribute indices. * * @param attIndex1 the 1st index of attribute. * @param attIndex2 the 2nd index of attribute. */ public void setAttIndices(int attIndex1, int attIndex2) { this.attIndex1 = attIndex1; this.attIndex2 = attIndex2; } public Function2D build(Instances instances) { // Note: Currently only support cutting on binned/nominal features List attributes = instances.getAttributes(); int size1 = 0; Attribute f1 = attributes.get(attIndex1); if (f1.getType() == Attribute.Type.BINNED) { size1 = ((BinnedAttribute) f1).getNumBins(); } else if (f1.getType() == Attribute.Type.NOMINAL) { size1 = ((NominalAttribute) f1).getCardinality(); } int size2 = 0; Attribute f2 = attributes.get(attIndex2); if (f2.getType() == Attribute.Type.BINNED) { size2 = ((BinnedAttribute) f2).getNumBins(); } else if (f2.getType() == Attribute.Type.NOMINAL) { size2 = ((NominalAttribute) f2).getCardinality(); } Histogram2D hist2d = new Histogram2D(size1, size2); Histogram2D.computeHistogram2D(instances, f1.getIndex(), f2.getIndex(), hist2d); Pair cHist = hist2d.computeCHistogram(); if ((size1 == 1 && !cHist.v1.hasMissingValue()) || (size2 == 1 && !cHist.v2.hasMissingValue())) { // Not an interaction // Recommend: Use LineCutter to shape the non-trivial attribute return new Function2D(f1.getIndex(), f2.getIndex(), new double[] { Double.POSITIVE_INFINITY }, new double[] { Double.POSITIVE_INFINITY }, new double[1][1]); } Histogram2D.Table table = Histogram2D.computeTable(hist2d, 
cHist.v1, cHist.v2); double bestRSS = Double.POSITIVE_INFINITY; double[] predInt1 = new double[9]; int bestV1 = -1; int[] bestV2s = new int[3]; int[] v2s = new int[3]; v2s[2] = -1; for (int v1 = 0; v1 < size1 - 1; v1++) { findCuts(table, v1, v2s, cHist.v1.hasMissingValue()); getPredictor(table, v1, v2s, predInt1); double rss = getRSS(table, v1, v2s, predInt1); if (rss < bestRSS) { bestRSS = rss; bestV1 = v1; bestV2s[0] = v2s[0]; bestV2s[1] = v2s[1]; bestV2s[2] = v2s[2]; } } boolean cutOnAttr2 = false; double[] predInt2 = new double[9]; int[] bestV1s = new int[3]; int bestV2 = -1; int[] v1s = new int[3]; v1s[2] = -1; for (int v2 = 0; v2 < size2 - 1; v2++) { findCuts(table, v1s, v2, cHist.v2.hasMissingValue()); getPredictor(table, v1s, v2, predInt2); double rss = getRSS(table, v1s, v2, predInt2); if (rss < bestRSS) { bestRSS = rss; bestV2 = v2; bestV1s[0] = v1s[0]; bestV1s[1] = v1s[1]; bestV1s[2] = v1s[2]; cutOnAttr2 = true; } } if (cutOnAttr2) { // Root cut on attribute 2 is better getPredictor(table, bestV1s, bestV2, predInt2); if (lineSearch) { lineSearch(instances, f2.getIndex(), f1.getIndex(), bestV2, bestV1s[0], bestV1s[1], bestV1s[2], predInt2); } return getFunction2D(f1.getIndex(), f2.getIndex(), bestV1s, bestV2, predInt2); } else { // Root cut on attribute 1 is better getPredictor(table, bestV1, bestV2s, predInt1); if (lineSearch) { lineSearch(instances, f1.getIndex(), f2.getIndex(), bestV1, bestV2s[0], bestV2s[1], bestV2s[2], predInt1); } return getFunction2D(f1.getIndex(), f2.getIndex(), bestV1, bestV2s, predInt1); } } protected static void findCuts(Histogram2D.Table table, int v1, int[] v2, boolean hasMissingValue) { // Find upper cut double bestEval = Double.POSITIVE_INFINITY; for (int i = 0; i < table.resp[v1].length - 1; i++) { double[] resp = table.resp[v1][i]; double[] count = table.count[v1][i]; double sum1 = resp[0]; double sum2 = resp[1]; double weight1 = count[0]; double weight2 = count[1]; double eval1 = OptimUtils.getGain(sum1, weight1); double 
eval2 = OptimUtils.getGain(sum2, weight2); double eval = -(eval1 + eval2); if (eval < bestEval) { bestEval = eval; v2[0] = i; } } // Find lower cut bestEval = Double.POSITIVE_INFINITY; for (int i = 0; i < table.resp[v1].length - 1; i++) { double[] resp = table.resp[v1][i]; double[] count = table.count[v1][i]; double sum1 = resp[2]; double sum2 = resp[3]; double weight1 = count[2]; double weight2 = count[3]; double eval1 = OptimUtils.getGain(sum1, weight1); double eval2 = OptimUtils.getGain(sum2, weight2); double eval = -(eval1 + eval2); if (eval < bestEval) { bestEval = eval; v2[1] = i; } } if (hasMissingValue) { // Find cut on missing value bestEval = Double.POSITIVE_INFINITY; for (int i = 0; i < table.respOnMV1.length; i++) { double[] respOnMV1 = table.respOnMV1[i]; double[] countOnMV1 = table.countOnMV1[i]; double sum1 = respOnMV1[0]; double sum2 = respOnMV1[1]; double weight1 = countOnMV1[0]; double weight2 = countOnMV1[1]; double eval1 = OptimUtils.getGain(sum1, weight1); double eval2 = OptimUtils.getGain(sum2, weight2); double eval = -(eval1 + eval2); if (eval < bestEval) { bestEval = eval; v2[2] = i; } } } } protected static void findCuts(Histogram2D.Table table, int[] v1, int v2, boolean hasMissingValue) { // Find left cut double bestEval = Double.POSITIVE_INFINITY; for (int i = 0; i < table.resp.length - 1; i++) { double[] resp = table.resp[i][v2]; double[] count = table.count[i][v2]; double sum1 = resp[0]; double sum2 = resp[2]; double weight1 = count[0]; double weight2 = count[2]; double eval1 = OptimUtils.getGain(sum1, weight1); double eval2 = OptimUtils.getGain(sum2, weight2); double eval = -(eval1 + eval2); if (eval < bestEval) { bestEval = eval; v1[0] = i; } } // Find right cut bestEval = Double.POSITIVE_INFINITY; for (int i = 0; i < table.resp.length - 1; i++) { double[] resp = table.resp[i][v2]; double[] count = table.count[i][v2]; double sum1 = resp[1]; double sum2 = resp[3]; double weight1 = count[1]; double weight2 = count[3]; double eval1 = 
OptimUtils.getGain(sum1, weight1); double eval2 = OptimUtils.getGain(sum2, weight2); double eval = -(eval1 + eval2); if (eval < bestEval) { bestEval = eval; v1[1] = i; } } if (hasMissingValue) { // Find cut on missing value bestEval = Double.POSITIVE_INFINITY; for (int i = 0; i < table.respOnMV2.length; i++) { double[] respOnMV2 = table.respOnMV2[i]; double[] countOnMV2 = table.countOnMV2[i]; double sum1 = respOnMV2[0]; double sum2 = respOnMV2[1]; double weight1 = countOnMV2[0]; double weight2 = countOnMV2[1]; double eval1 = OptimUtils.getGain(sum1, weight1); double eval2 = OptimUtils.getGain(sum2, weight2); double eval = -(eval1 + eval2); if (eval < bestEval) { bestEval = eval; v1[2] = i; } } } } protected static void getPredictor(Histogram2D.Table table, int v1, int[] v2, double[] pred) { int v21 = v2[0]; int v22 = v2[1]; int vMV = v2[2]; double[] resp1 = table.resp[v1][v21]; double[] count1 = table.count[v1][v21]; double[] resp2 = table.resp[v1][v22]; double[] count2 = table.count[v1][v22]; pred[0] = MathUtils.divide(resp1[0], count1[0], 0); pred[1] = MathUtils.divide(resp1[1], count1[1], 0); pred[2] = MathUtils.divide(resp2[2], count2[2], 0); pred[3] = MathUtils.divide(resp2[3], count2[3], 0); if (vMV >= 0) { double[] respOnMV1 = table.respOnMV1[vMV]; double[] countOnMV1 = table.countOnMV1[vMV]; pred[4] = MathUtils.divide(respOnMV1[0], countOnMV1[0], 0); pred[5] = MathUtils.divide(respOnMV1[1], countOnMV1[1], 0); } double[] respOnMV2 = table.respOnMV2[v1]; double[] countOnMV2 = table.countOnMV2[v1]; pred[6] = MathUtils.divide(respOnMV2[0], countOnMV2[0], 0); pred[7] = MathUtils.divide(respOnMV2[1], countOnMV2[1], 0); pred[8] = MathUtils.divide(table.respOnMV12, table.countOnMV12, 0); } protected static void getPredictor(Histogram2D.Table table, int[] v1, int v2, double[] pred) { int v11 = v1[0]; int v12 = v1[1]; int vMV = v1[2]; double[] resp1 = table.resp[v11][v2]; double[] count1 = table.count[v11][v2]; double[] resp2 = table.resp[v12][v2]; double[] count2 = 
table.count[v12][v2]; pred[0] = MathUtils.divide(resp1[0], count1[0], 0); pred[1] = MathUtils.divide(resp1[2], count1[2], 0); pred[2] = MathUtils.divide(resp2[1], count2[1], 0); pred[3] = MathUtils.divide(resp2[3], count2[3], 0); if (vMV >= 0) { double[] respOnMV2 = table.respOnMV2[vMV]; double[] countOnMV2 = table.countOnMV2[vMV]; pred[4] = MathUtils.divide(respOnMV2[0], countOnMV2[0], 0); pred[5] = MathUtils.divide(respOnMV2[1], countOnMV2[1], 0); } double[] respOnMV1 = table.respOnMV1[v2]; double[] countOnMV1 = table.countOnMV1[v2]; pred[6] = MathUtils.divide(respOnMV1[0], countOnMV1[0], 0); pred[7] = MathUtils.divide(respOnMV1[1], countOnMV1[1], 0); pred[8] = MathUtils.divide(table.respOnMV12, table.countOnMV12, 0); } protected static double getRSS(Histogram2D.Table table, int v1, int v2[], double[] pred) { int v21 = v2[0]; int v22 = v2[1]; int vMV = v2[2]; double[] resp1 = table.resp[v1][v21]; double[] resp2 = table.resp[v1][v22]; double[] count1 = table.count[v1][v21]; double[] count2 = table.count[v1][v22]; double rss = 0; rss += pred[0] * pred[0] * count1[0]; rss += pred[1] * pred[1] * count1[1]; rss += pred[2] * pred[2] * count2[2]; rss += pred[3] * pred[3] * count2[3]; if (vMV >= 0) { double[] countOnMV1 = table.countOnMV1[vMV]; rss += pred[4] * pred[4] * countOnMV1[0]; rss += pred[5] * pred[5] * countOnMV1[1]; } double[] countOnMV2 = table.countOnMV2[v1]; rss += pred[6] * pred[6] * countOnMV2[0]; rss += pred[7] * pred[7] * countOnMV2[1]; rss += pred[8] * pred[8] * table.countOnMV12; double t = 0; t += pred[0] * resp1[0]; t += pred[1] * resp1[1]; t += pred[2] * resp2[2]; t += pred[3] * resp2[3]; if (vMV >= 0) { double[] respOnMV1 = table.respOnMV1[vMV]; t += pred[4] * respOnMV1[0]; t += pred[5] * respOnMV1[1]; } double[] respOnMV2 = table.respOnMV2[v1]; t += pred[6] * respOnMV2[0]; t += pred[7] * respOnMV2[1]; t += pred[8] * table.respOnMV12; rss -= 2 * t; return rss; } protected static double getRSS(Histogram2D.Table table, int[] v1, int v2, double[] 
pred) { int v11 = v1[0]; int v12 = v1[1]; int vMV = v1[2]; double[] resp1 = table.resp[v11][v2]; double[] resp2 = table.resp[v12][v2]; double[] count1 = table.count[v11][v2]; double[] count2 = table.count[v12][v2]; double rss = 0; rss += pred[0] * pred[0] * count1[0]; rss += pred[1] * pred[1] * count1[2]; rss += pred[2] * pred[2] * count2[1]; rss += pred[3] * pred[3] * count2[3]; if (vMV >= 0) { double[] countOnMV2 = table.countOnMV2[vMV]; rss += pred[4] * pred[4] * countOnMV2[0]; rss += pred[5] * pred[5] * countOnMV2[1]; } double[] countOnMV1 = table.countOnMV1[v2]; rss += pred[6] * pred[6] * countOnMV1[0]; rss += pred[7] * pred[7] * countOnMV1[1]; rss += pred[8] * pred[8] * table.countOnMV12; double t = 0; t += pred[0] * resp1[0]; t += pred[1] * resp1[2]; t += pred[2] * resp2[1]; t += pred[3] * resp2[3]; if (vMV >= 0) { double[] respOnMV2 = table.respOnMV2[vMV]; t += pred[4] * respOnMV2[0]; t += pred[5] * respOnMV2[1]; } double[] respOnMV1 = table.respOnMV1[v2]; t += pred[6] * respOnMV1[0]; t += pred[7] * respOnMV1[1]; t += pred[8] * table.respOnMV12; rss -= 2 * t; return rss; }

    /**
     * Assembles the 2D function for a root cut on attribute 1.
     *
     * <p>
     * Attribute 1 is cut at {@code v1}. Attribute 2 is cut at {@code v2[0]} where attribute 1
     * is at or below {@code v1}, at {@code v2[1]} where it is above, and at {@code v2[2]}
     * where attribute 1 is missing ({@code v2[2] < 0} means no such cut, in which case the
     * predictions for missing attribute 1 are all {@code 0.0}). The split points on
     * attribute 2 are the sorted distinct cuts followed by positive infinity; each interval
     * takes the "at or below" prediction of a cut exactly when the interval's upper bound
     * does not exceed that cut.
     * </p>
     *
     * <p>
     * The intervals are derived from the cuts instead of being enumerated case by case. The
     * previous enumeration tested {@code v2[2] == v2[1]} twice when {@code v2[0] > v2[1]},
     * so a missing-value cut equal to {@code v2[0]} produced a duplicated split point.
     * </p>
     *
     * @param attIndex1 the index of the 1st attribute.
     * @param attIndex2 the index of the 2nd attribute.
     * @param v1 the cut on attribute 1.
     * @param v2 the three cuts on attribute 2.
     * @param predInt the nine region predictions.
     * @return the 2D function.
     */
    protected static Function2D getFunction2D(int attIndex1, int attIndex2, int v1, int[] v2, double[] predInt) {
        double[] splits1 = new double[] { v1, Double.POSITIVE_INFINITY };
        double[] predictionsOnMV2 = new double[] { predInt[6], predInt[7] };

        int[] cuts = sortedDistinctCuts(v2);
        int numIntervals = cuts.length + 1;
        double[] splits2 = new double[numIntervals];
        double[][] predictions = new double[2][numIntervals];
        // Stays all zero when there is no cut for missing attribute 1.
        double[] predictionsOnMV1 = new double[numIntervals];
        for (int j = 0; j < numIntervals; j++) {
            splits2[j] = j < cuts.length ? cuts[j] : Double.POSITIVE_INFINITY;
            predictions[0][j] = pick(cuts, j, v2[0], predInt[0], predInt[1]);
            predictions[1][j] = pick(cuts, j, v2[1], predInt[2], predInt[3]);
            if (v2[2] >= 0) {
                predictionsOnMV1[j] = pick(cuts, j, v2[2], predInt[4], predInt[5]);
            }
        }

        return new Function2D(attIndex1, attIndex2, splits1, splits2, predictions, predictionsOnMV1,
                predictionsOnMV2, predInt[8]);
    }

    /**
     * Assembles the 2D function for a root cut on attribute 2.
     *
     * <p>
     * Attribute 2 is cut at {@code v2}. Attribute 1 is cut at {@code v1[0]} where attribute 2
     * is at or below {@code v2}, at {@code v1[1]} where it is above, and at {@code v1[2]}
     * where attribute 2 is missing ({@code v1[2] < 0} means no such cut, in which case the
     * predictions for missing attribute 2 are all {@code 0.0}).
     * </p>
     *
     * <p>
     * The intervals are derived from the cuts instead of being enumerated case by case. The
     * previous enumeration had two wrong cases: for {@code v1[1] < v1[0] < v1[2]} it reused
     * the prediction table of the {@code v1[2] < v1[1]} case, and for
     * {@code v1[0] == v1[1] < v1[2]} it emitted the split points in descending order.
     * </p>
     *
     * @param attIndex1 the index of the 1st attribute.
     * @param attIndex2 the index of the 2nd attribute.
     * @param v1 the three cuts on attribute 1.
     * @param v2 the cut on attribute 2.
     * @param predInt the nine region predictions.
     * @return the 2D function.
     */
    protected static Function2D getFunction2D(int attIndex1, int attIndex2, int[] v1, int v2, double[] predInt) {
        double[] splits2 = new double[] { v2, Double.POSITIVE_INFINITY };
        double[] predictionsOnMV1 = new double[] { predInt[6], predInt[7] };

        int[] cuts = sortedDistinctCuts(v1);
        int numIntervals = cuts.length + 1;
        double[] splits1 = new double[numIntervals];
        double[][] predictions = new double[numIntervals][2];
        // Stays all zero when there is no cut for missing attribute 2.
        double[] predictionsOnMV2 = new double[numIntervals];
        for (int i = 0; i < numIntervals; i++) {
            splits1[i] = i < cuts.length ? cuts[i] : Double.POSITIVE_INFINITY;
            predictions[i][0] = pick(cuts, i, v1[0], predInt[0], predInt[1]);
            predictions[i][1] = pick(cuts, i, v1[1], predInt[2], predInt[3]);
            if (v1[2] >= 0) {
                predictionsOnMV2[i] = pick(cuts, i, v1[2], predInt[4], predInt[5]);
            }
        }

        return new Function2D(attIndex1, attIndex2, splits1, splits2, predictions, predictionsOnMV1,
                predictionsOnMV2, predInt[8]);
    }

    /**
     * Returns the distinct cut points in ascending order. The first two entries of
     * {@code cuts} are always used; the third is used only when it is non-negative.
     */
    private static int[] sortedDistinctCuts(int[] cuts) {
        int n = cuts[2] < 0 ? 2 : 3;
        // Fully qualified so that the file's import list does not need to change.
        int[] sorted = java.util.Arrays.copyOf(cuts, n);
        java.util.Arrays.sort(sorted);
        int k = 0;
        for (int i = 0; i < n; i++) {
            if (i == 0 || sorted[i] != sorted[i - 1]) {
                sorted[k++] = sorted[i];
            }
        }
        return java.util.Arrays.copyOf(sorted, k);
    }

    /**
     * Selects the prediction for one interval relative to one cut: {@code atOrBelow} when the
     * interval's upper bound {@code cuts[interval]} does not exceed {@code cut}, and
     * {@code above} otherwise (including the last, unbounded interval).
     */
    private static double pick(int[] cuts, int interval, int cut, double atOrBelow, double above) {
        return interval < cuts.length && cuts[interval] <= cut ? atOrBelow : above;
    }

    protected static void lineSearch(Instances instances, int attIndex1, int attIndex2, int c1, int c21, int c22, int cMV, double[] predictions) { double[] numerator = new double[9]; double[] denominator = new double[9]; for (Instance instance : instances) {
final double target = instance.getTarget(); final double weight = instance.getWeight(); final double t = Math.abs(target); final double num = target * weight; final double den = t * (1 - t) * weight; if (!instance.isMissing(attIndex1) && !instance.isMissing(attIndex2)) { int v1 = (int) instance.getValue(attIndex1); int v2 = (int) instance.getValue(attIndex2); if (v1 <= c1) { if (v2 <= c21) { numerator[0] += num; denominator[0] += den; } else { numerator[1] += num; denominator[1] += den; } } else { if (v2 <= c22) { numerator[2] += num; denominator[2] += den; } else { numerator[3] += num; denominator[3] += den; } } } else if (instance.isMissing(attIndex1) && !instance.isMissing(attIndex2)) { int v2 = (int) instance.getValue(attIndex2); if (cMV >= 0) { if (v2 <= cMV) { numerator[4] += num; denominator[4] += den; } else { numerator[5] += num; denominator[5] += den; } } else { throw new RuntimeException("Something went wrong"); } } else if (!instance.isMissing(attIndex1) && instance.isMissing(attIndex2)) { int v1 = (int) instance.getValue(attIndex1); if (v1 <= c1) { numerator[6] += num; denominator[6] += den; } else { numerator[7] += num; denominator[7] += den; } } else { numerator[8] += num; denominator[8] += den; } } for (int i = 0; i < numerator.length; i++) { predictions[i] = MathUtils.divide(numerator[i], denominator[i], 0); } } }
================================================ FILE: src/main/java/mltk/predictor/function/SubagSequence.java ================================================
package mltk.predictor.function;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;

import mltk.util.Element;
import mltk.util.Permutation;
import mltk.util.Queue;
import mltk.util.UFSets;
import mltk.util.tuple.IntPair;

/**
 * A set of subsamples together with an ordering in which each subsample (except the first
 * one visited) is reached from an already visited subsample by adding and removing indices.
 *
 * <p>
 * The ordering is a breadth-first traversal of a spanning tree over the subsamples. The tree
 * is grown by polling sample pairs from a priority queue keyed by their distance and joining
 * pairs that are not yet connected.
 * NOTE(review): this yields a minimum spanning tree only if {@code Element} orders by
 * ascending weight; confirm against {@code mltk.util.Element}.
 * </p>
 */
class SubagSequence {

    /**
     * The indices to add to and delete from one sample to obtain another.
     */
    static class SampleDelta {

        int[] toAdd;
        int[] toDel;

        SampleDelta(Set<Integer> toAdd, Set<Integer> toDel) {
            this.toAdd = new int[toAdd.size()];
            int k = 0;
            for (Integer idx : toAdd) {
                this.toAdd[k++] = idx;
            }
            k = 0;
            this.toDel = new int[toDel.size()];
            for (Integer idx : toDel) {
                this.toDel[k++] = idx;
            }
        }

        /** Returns the number of additions plus the number of deletions. */
        int getDistance() {
            return toAdd.length + toDel.length;
        }

    }

    /**
     * A subsample, held both as a set and as an array of indices.
     */
    static class Sample {

        Set<Integer> set;
        int[] indices;

        Sample(Set<Integer> set) {
            this.set = set;
            this.indices = new int[set.size()];
            int k = 0;
            for (Integer idx : set) {
                this.indices[k++] = idx;
            }
        }

        /** Returns the size of the symmetric difference between the two samples. */
        static int computeDistance(Sample s1, Sample s2) {
            SampleDelta delta = computeDelta(s1, s2);
            return delta.getDistance();
        }

        /** Returns the changes that turn {@code s1} into {@code s2}. */
        static SampleDelta computeDelta(Sample s1, Sample s2) {
            Set<Integer> toAdd = new HashSet<>(s2.set);
            toAdd.removeAll(s1.set);
            Set<Integer> toDel = new HashSet<>(s1.set);
            toDel.removeAll(s2.set);
            return new SampleDelta(toAdd, toDel);
        }

        /** Returns the number of indices in this sample. */
        int getWeight() {
            return set.size();
        }

    }

    Sample[] samples;
    /** For each traversal step, the sample the step starts from. */
    int[] start;
    /** For each traversal step, the sample the step produces. */
    int[] end;
    /** For each traversal step, the changes from the start sample to the end sample. */
    SampleDelta[] deltas;
    /** For each sample, the number of traversal steps that start from it. */
    int[] count;

    /**
     * Draws the subsamples and computes the traversal order.
     *
     * @param n the size of the data set.
     * @param m the size of each subsample.
     * @param baggingIters the number of subsamples.
     */
    SubagSequence(int n, int m, int baggingIters) {
        samples = new Sample[baggingIters];
        count = new int[baggingIters];
        for (int i = 0; i < samples.length; i++) {
            samples[i] = createSubsample(n, m);
        }
        computeSequence(samples);
    }

    /**
     * Creates a subsample from the first {@code m} entries of a permutation of size
     * {@code n}.
     */
    Sample createSubsample(int n, int m) {
        Permutation perm = new Permutation(n);
        perm.permute();
        int[] a = perm.getPermutation();
        Set<Integer> set = new HashSet<>(m);
        for (int i = 0; i < m; i++) {
            set.add(a[i]);
        }
        return new Sample(set);
    }

    private void computeSequence(Sample[] samples) {
        // PriorityQueue rejects an initial capacity below 1, which is what the pair count
        // evaluates to for zero or one sample; with no pairs the tree is simply empty.
        int numPairs = samples.length * (samples.length - 1) / 2;
        PriorityQueue<Element<IntPair>> q = new PriorityQueue<>(Math.max(1, numPairs));
        for (int i = 0; i < samples.length - 1; i++) {
            for (int j = i + 1; j < samples.length; j++) {
                int distance = Sample.computeDistance(samples[i], samples[j]);
                q.add(new Element<>(new IntPair(i, j), distance));
            }
        }

        // Adjacency sets of the spanning tree
        Map<Integer, Set<Integer>> map = new HashMap<>();
        UFSets ufsets = new UFSets(samples.length);
        while (!q.isEmpty()) {
            Element<IntPair> e = q.poll();
            int x = e.element.v1;
            int y = e.element.v2;
            int root1 = ufsets.find(x);
            int root2 = ufsets.find(y);
            if (root1 != root2) {
                ufsets.union(root1, root2);
                if (!map.containsKey(x)) {
                    map.put(x, new HashSet<Integer>());
                }
                map.get(x).add(y);
                if (!map.containsKey(y)) {
                    map.put(y, new HashSet<Integer>());
                }
                map.get(y).add(x);
            }
        }

        // The traversal starts from the smallest sample (first one on ties)
        int s = 0;
        for (int i = 1; i < samples.length; i++) {
            int weight = samples[i].getWeight();
            if (weight < samples[s].getWeight()) {
                s = i;
            }
        }

        List<Integer> fromList = new ArrayList<>(samples.length);
        List<Integer> toList = new ArrayList<>(samples.length);
        List<SampleDelta> deltaList = new ArrayList<>(samples.length);
        // BFS
        Set<Integer> covered = new HashSet<>();
        covered.add(s);
        Queue<Integer> queue = new Queue<>();
        queue.enqueue(s);
        while (covered.size() < samples.length) {
            Integer node = queue.dequeue();
            Set<Integer> children = map.get(node);
            for (Integer child : children) {
                if (covered.contains(child)) {
                    continue;
                }
                covered.add(child);
                fromList.add(node);
                toList.add(child);
                SampleDelta delta = Sample.computeDelta(samples[node], samples[child]);
                deltaList.add(delta);
                queue.enqueue(child);
            }
        }

        this.start = new int[fromList.size()];
        this.end = new int[toList.size()];
        this.deltas = new SampleDelta[deltaList.size()];
        for (int i = 0; i < this.start.length; i++) {
            this.start[i] = fromList.get(i);
            this.end[i] = toList.get(i);
            this.deltas[i] = deltaList.get(i);
        }

        for (int i = 0; i < this.start.length; i++) {
            count[this.start[i]]++;
        }
    }

}
================================================ FILE: src/main/java/mltk/predictor/function/SubaggedLineCutter.java ================================================
package mltk.predictor.function;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.predictor.BaggedEnsemble;
import mltk.predictor.function.SubagSequence.SampleDelta;
import mltk.util.MathUtils;

/**
 * Class for cutting lines with subagging.
 *
 * @author Yin Lou
 *
 */
public class SubaggedLineCutter extends EnsembledLineCutter {

    private int subsampleSize;
    private SubagSequence ss;

    /**
     * Constructor.
*/ public SubaggedLineCutter() { this(false); } /** * Constructor. * * @param isClassification {@code true} if it is a classification problem. */ public SubaggedLineCutter(boolean isClassification) { attIndex = -1; this.isClassification = isClassification; } /** * Creates internal subsamples. * * @param n the size of the data set. * @param subsampleRatio the subsample ratio. * @param baggingIters the number of bagging iterations. */ public void createSubags(int n, double subsampleRatio, int baggingIters) { this.subsampleSize = (int) (n * subsampleRatio); ss = new SubagSequence(n, subsampleSize, baggingIters); } @Override public BaggedEnsemble build(Instances instances) { return build(instances, attIndex, numIntervals); } class Histogram { double[][] histogram; Histogram(double[][] histogram) { this.histogram = histogram; } Histogram copy() { double[][] newHistogram = new double[histogram.length][histogram[0].length]; for (int i = 0; i < histogram.length; i++) { double[] hist = histogram[i]; double[] newHist = newHistogram[i]; for (int j = 0; j < hist.length; j++) { newHist[j] = hist[j]; } } return new Histogram(newHistogram); } } /** * Builds an 1D function ensemble. * * @param instances the training set. * @param attIndex the attribute index. * @param numIntervals the number of intervals. * @return an 1D function ensemble. */ public BaggedEnsemble build(Instances instances, int attIndex, int numIntervals) { Attribute attribute = instances.getAttributes().get(attIndex); return build(instances, attribute, numIntervals); } /** * Builds an 1D function ensemble. * * @param instances the training set. * @param attribute the attribute. * @param numIntervals the number of intervals. * @return an 1D function ensemble. 
*/ public BaggedEnsemble build(Instances instances, Attribute attribute, int numIntervals) { if (ss == null) { ss = new SubagSequence(instances.size(), subsampleSize, baggingIters); } SubagSequence.Sample[] samples = ss.samples; int[] start = ss.start; int[] end = ss.end; SampleDelta[] deltas = ss.deltas; int[] count = Arrays.copyOf(ss.count, ss.count.length); Function1D[] funcs = new Function1D[ss.samples.length]; int attIndex = attribute.getIndex(); double[] targets = new double[instances.size()]; double[] fvalues = new double[instances.size()]; double[] weights = new double[instances.size()]; for (int i = 0; i < instances.size(); i++) { Instance instance = instances.get(i); targets[i] = instance.getTarget(); fvalues[i] = instance.getValue(attribute); weights[i] = instance.getWeight(); } if (attribute.getType() == Attribute.Type.NUMERIC) { throw new RuntimeException("Not implemented yet!"); } else { int size = 0; if (attribute.getType() == Attribute.Type.BINNED) { size = ((BinnedAttribute) attribute).getNumBins(); } else { size = ((NominalAttribute) attribute).getCardinality(); } Histogram[] histograms = new Histogram[ss.samples.length]; {// Initialization double[][] histogram = new double[size + 1][2]; for (int index : samples[start[0]].indices) { double value = fvalues[index]; double weight = weights[index]; double target = targets[index]; int idx = histogram.length - 1; if (!Double.isNaN(value)) { idx = (int) value; } if (isClassification) { histogram[idx][0] += target; } else { histogram[idx][0] += target * weight; } histogram[idx][1] += weight; } histograms[0] = new Histogram(histogram); funcs[0] = buildFromHistogram(attIndex, histograms[0]); } for (int k = 0; k < start.length; k++) { int from = start[k]; Histogram histStart = histograms[from]; count[from]--; // Generate histogram for its child int to = end[k]; Histogram histEnd = null; if (count[from] == 0) { histEnd = histStart; } else { histEnd = histStart.copy(); } SampleDelta delta = deltas[k]; for (int 
index : delta.toAdd) { double value = fvalues[index]; double weight = weights[index]; double target = targets[index]; int idx = histEnd.histogram.length - 1; if (!Double.isNaN(value)) { idx = (int) value; } if (isClassification) { histEnd.histogram[idx][0] += target; } else { histEnd.histogram[idx][0] += target * weight; } histEnd.histogram[idx][1] += weight; } for (int index : delta.toDel) { double value = fvalues[index]; double weight = weights[index]; double target = targets[index]; int idx = histEnd.histogram.length - 1; if (!Double.isNaN(value)) { idx = (int) value; } if (isClassification) { histEnd.histogram[idx][0] -= target; } else { histEnd.histogram[idx][0] -= target * weight; } histEnd.histogram[idx][1] -= weight; } histograms[to] = histEnd; funcs[to] = buildFromHistogram(attIndex, histEnd); } } BaggedEnsemble ensemble = new BaggedEnsemble(ss.samples.length); for (Function1D func : funcs) { ensemble.add(func); } return ensemble; } protected Function1D buildFromHistogram(int attIndex, Histogram histogram) { final int size = histogram.histogram.length; List histograms = new ArrayList<>(size); for (int i = 0; i < size - 1; i++) { if (!MathUtils.isZero(histogram.histogram[i][1])) { double[] hist = histogram.histogram[i]; histograms.add(new double[] {i, hist[0], hist[1]}); } } histograms.add(new double[] { Double.NaN, histogram.histogram[size - 1][0], histogram.histogram[size - 1][1] }); Function1D func = LineCutter.build(attIndex, histograms, numIntervals); return func; } } ================================================ FILE: src/main/java/mltk/predictor/function/UnivariateFunction.java ================================================ package mltk.predictor.function; /** * Interface for univariate functions. * * @author Yin Lou * */ public interface UnivariateFunction { /** * Computes the value for the function. * * @param x the argument. * @return the value for the function. 
*/
	public double evaluate(double x);

}

================================================
FILE: src/main/java/mltk/predictor/function/package-info.java
================================================
/**
 * Provides classes for simple functions, such as univariate/bivariate functions.
 */
package mltk.predictor.function;

================================================
FILE: src/main/java/mltk/predictor/gam/DenseDesignMatrix.java
================================================
package mltk.predictor.gam;

import java.util.HashSet;
import java.util.Set;

import mltk.predictor.function.CubicSpline;
import mltk.util.StatUtils;
import mltk.util.VectorUtils;

/**
 * Dense design matrix made of one block of basis columns per feature.
 */
class DenseDesignMatrix {

	/** Basis columns: {@code x[j][b]} is the b-th basis column of feature j. */
	double[][][] x;
	/** Knot positions of every feature; empty for features expanded linearly only. */
	double[][] knots;
	/** Scale every basis column was divided by. */
	double[][] std;

	DenseDesignMatrix(double[][][] x, double[][] knots, double[][] std) {
		this.x = x;
		this.knots = knots;
		this.std = std;
	}

	/**
	 * Creates a cubic spline design matrix.
	 *
	 * A feature with more than {@code numKnots} distinct values is expanded into {@code numKnots + 3} columns:
	 * the value, its square, its cube and one {@link CubicSpline#h} column per knot, with knots placed at
	 * {@code min + k * (max - min) / numKnots}. Any other feature keeps its single original column. Every
	 * column is divided by its scale (standard deviation over {@code sqrt(n)}).
	 *
	 * The first column of every block is the row of {@code dataset} itself, not a copy.
	 * NOTE(review): {@code VectorUtils.divide} presumably works in place, which would rescale the caller's
	 * data as well - confirm against its implementation.
	 *
	 * @param dataset the features, one row per feature.
	 * @param stdList the standard deviation of every feature.
	 * @param numKnots the number of knots.
	 * @return the design matrix.
	 */
	static DenseDesignMatrix createCubicSplineDesignMatrix(double[][] dataset, double[] stdList, int numKnots) {
		final int n = dataset[0].length;
		final int p = dataset.length;
		double[][][] x = new double[p][][];
		double[][] knots = new double[p][];
		double[][] std = new double[p][];
		double factor = Math.sqrt(n);
		for (int j = 0; j < dataset.length; j++) {
			Set<Double> uniqueValues = new HashSet<>();
			double[] x1 = dataset[j];
			for (int i = 0; i < n; i++) {
				uniqueValues.add(x1[i]);
			}
			// Features with too few distinct values are not expanded into splines.
			int nKnots = uniqueValues.size() <= numKnots ? 0 : numKnots;
			knots[j] = new double[nKnots];
			if (nKnots != 0) {
				x[j] = new double[nKnots + 3][];
				std[j] = new double[nKnots + 3];
			} else {
				x[j] = new double[1][];
				std[j] = new double[1];
			}
			double[][] t = x[j];
			t[0] = x1;
			std[j][0] = stdList[j] / factor;
			if (nKnots != 0) {
				double[] x2 = new double[n];
				for (int i = 0; i < n; i++) {
					x2[i] = x1[i] * x1[i];
				}
				t[1] = x2;
				double[] x3 = new double[n];
				for (int i = 0; i < n; i++) {
					x3[i] = x2[i] * x1[i];
				}
				t[2] = x3;
				std[j][1] = StatUtils.sd(x2) / factor;
				std[j][2] = StatUtils.sd(x3) / factor;
				double max = StatUtils.max(x1);
				double min = StatUtils.min(x1);
				double stepSize = (max - min) / nKnots;
				for (int k = 0; k < nKnots; k++) {
					knots[j][k] = min + stepSize * k;
					double[] basis = new double[n];
					for (int i = 0; i < n; i++) {
						basis[i] = CubicSpline.h(x1[i], knots[j][k]);
					}
					std[j][k + 3] = StatUtils.sd(basis) / factor;
					t[k + 3] = basis;
				}
			}
		}

		// Normalize the inputs
		for (int j = 0; j < p; j++) {
			double[][] block = x[j];
			double[] s = std[j];
			for (int i = 0; i < block.length; i++) {
				VectorUtils.divide(block[i], s[i]);
			}
		}

		return new DenseDesignMatrix(x, knots, std);
	}

}

================================================
FILE: src/main/java/mltk/predictor/gam/GA2MLearner.java
================================================
package mltk.predictor.gam;

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;

import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.cmdline.options.HoldoutValidatedLearnerWithTaskOptions;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.core.Sampling;
import mltk.core.Attribute.Type;
import mltk.core.io.InstancesReader;
import mltk.predictor.BaggedEnsemble;
import mltk.predictor.BaggedEnsembleLearner;
import mltk.predictor.BoostedEnsemble;
import mltk.predictor.HoldoutValidatedLearner;
import mltk.predictor.Regressor;
import mltk.predictor.evaluation.ConvergenceTester; import mltk.predictor.evaluation.Metric; import mltk.predictor.evaluation.MetricFactory; import mltk.predictor.evaluation.SimpleMetric; import mltk.predictor.function.Array2D; import mltk.predictor.function.CompressionUtils; import mltk.predictor.function.Function2D; import mltk.predictor.function.SquareCutter; import mltk.predictor.io.PredictorReader; import mltk.predictor.io.PredictorWriter; import mltk.util.OptimUtils; import mltk.util.Random; import mltk.util.tuple.IntPair; /** * Class for learning GA^2M models via gradient boosting. * *

* Reference:
* Y. Lou, R. Caruana, J. Gehrke, and G. Hooker. Accurate intelligible models with pairwise interactions. In * Proceedings of the 19th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), * Chicago, IL, USA, 2013. *

*
 * @author Yin Lou
 *
 */
public class GA2MLearner extends HoldoutValidatedLearner {

	/**
	 * Command line options of {@link #main(String[])}, in addition to the ones inherited from
	 * {@code HoldoutValidatedLearnerWithTaskOptions}.
	 */
	static class Options extends HoldoutValidatedLearnerWithTaskOptions {

		// Path of the model the learned pairwise terms are added to (required).
		@Argument(name = "-i", description = "input model path", required = true)
		String inputModelPath = null;

		// Path of the file listing the pairwise interactions to learn (required).
		@Argument(name = "-I", description = "list of pairwise interactions path", required = true)
		String interactionsPath = null;

		// Maximum number of boosting iterations (required).
		@Argument(name = "-m", description = "maximum number of iterations", required = true)
		int maxNumIters = -1;

		// Number of bagging iterations.
		@Argument(name = "-b", description = "bagging iterations (default: 100)")
		int baggingIters = 100;

		// Seed of the random number generator.
		@Argument(name = "-s", description = "seed of the random number generator (default: 0)")
		long seed = 0L;

		// Learning rate.
		@Argument(name = "-l", description = "learning rate (default: 0.01)")
		double learningRate = 0.01;

	}

	/**
	 * Trains a GA2M.
	 *
	 *
	 * Usage: mltk.predictor.gam.GA2MLearner
	 * -t	train set path
	 * -i	input model path
	 * -I	list of pairwise interactions path
	 * -m	maximum number of iterations
	 * [-g]	task between classification (c) and regression (r) (default: r)
	 * [-v]	valid set path
	 * [-e]	evaluation metric (default: default metric of task)
	 * [-S]	convergence criteria (default: -1)
	 * [-r]	attribute file path
	 * [-o]	output model path
	 * [-V]	verbose (default: true)
	 * [-b]	bagging iterations (default: 100)
	 * [-s]	seed of the random number generator (default: 0)
	 * [-l]	learning rate (default: 0.01)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(GA2MLearner.class, opts); Task task = null; Metric metric = null; try { parser.parse(args); task = Task.get(opts.task); if (opts.metric == null) { metric = task.getDefaultMetric(); } else { metric = MetricFactory.getMetric(opts.metric); } } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Random.getInstance().setSeed(opts.seed); ConvergenceTester ct = ConvergenceTester.parse(opts.cc); Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath); List terms = new ArrayList<>(); BufferedReader br = new BufferedReader(new FileReader(opts.interactionsPath)); for (;;) { String line = br.readLine(); if (line == null) { break; } String[] data = line.split("\\s+"); IntPair term = new IntPair(Integer.parseInt(data[0]), Integer.parseInt(data[1])); terms.add(term); } br.close(); GAM gam = PredictorReader.read(opts.inputModelPath, GAM.class); GA2MLearner learner = new GA2MLearner(); learner.setBaggingIters(opts.baggingIters); learner.setGAM(gam); learner.setMaxNumIters(opts.maxNumIters); learner.setTask(task); learner.setMetric(metric); learner.setConvergenceTester(ct); learner.setPairs(terms); learner.setLearningRate(opts.learningRate); learner.setVerbose(opts.verbose); if (opts.validPath != null) { Instances validSet = InstancesReader.read(opts.attPath, opts.validPath); learner.setValidSet(validSet); } long start = System.currentTimeMillis(); learner.build(trainSet); long end = System.currentTimeMillis(); System.out.println("Time: " + (end - start) / 1000.0); if (opts.outputModelPath != null) { PredictorWriter.write(gam, opts.outputModelPath); } } private int baggingIters; private int maxNumIters; private Task task; private double learningRate; private GAM gam; private List pairs; /** * Constructor. 
*/ public GA2MLearner() { verbose = false; baggingIters = 100; maxNumIters = -1; learningRate = 0.01; task = Task.REGRESSION; metric = task.getDefaultMetric(); } /** * Returns the number of bagging iterations. * * @return the number of bagging iterations. */ public int getBaggingIters() { return baggingIters; } /** * Sets the number of bagging iterations. * * @param baggingIters the number of bagging iterations. */ public void setBaggingIters(int baggingIters) { this.baggingIters = baggingIters; } /** * Returns the maximum number of iterations. * * @return the maximum number of iterations. */ public int getMaxNumIters() { return maxNumIters; } /** * Sets the maximum number of iterations. * * @param maxNumIters the maximum number of iterations. */ public void setMaxNumIters(int maxNumIters) { this.maxNumIters = maxNumIters; } /** * Returns the learning rate. * * @return the learning rate. */ public double getLearningRate() { return learningRate; } /** * Sets the learning rate. * * @param learningRate the learning rate. */ public void setLearningRate(double learningRate) { this.learningRate = learningRate; } /** * Returns the task of this learner. * * @return the task of this learner. */ public Task getTask() { return task; } /** * Sets the task of this learner. * * @param task the new task. */ public void setTask(Task task) { this.task = task; } /** * Returns the GAM. * * @return the GAM. */ public GAM getGAM() { return gam; } /** * Sets the GAM. * * @param gam the GAM. */ public void setGAM(GAM gam) { this.gam = gam; } /** * Returns the list of feature interaction pairs. * * @return the list of feature interaction pairs. */ public List getPairs() { return pairs; } /** * Sets the list of feature interaction pairs. * * @param pairs the list of feature interaction pairs. */ public void setPairs(List pairs) { this.pairs = pairs; } /** * Builds a classifier. * * @param gam the GAM. * @param terms the list of feature interaction pairs. * @param trainSet the training set. 
* @param validSet the validation set. * @param maxNumIters the maximum number of iterations. */ public void buildClassifier(GAM gam, List terms, Instances trainSet, Instances validSet, int maxNumIters) { List regressors = new ArrayList<>(terms.size()); int[] indices = new int[terms.size()]; for (int i = 0; i < indices.length; i++) { indices[i] = indexOf(gam.terms, terms.get(i)); regressors.add(new BoostedEnsemble()); } // Backup targets double[] target = new double[trainSet.size()]; for (int i = 0; i < target.length; i++) { target[i] = trainSet.get(i).getTarget(); } // Create bags Instances[] bags = Sampling.createBags(trainSet, baggingIters); SquareCutter cutter = new SquareCutter(true); BaggedEnsembleLearner learner = new BaggedEnsembleLearner(bags.length, cutter); // Initialize predictions and residuals double[] pTrain = new double[trainSet.size()]; double[] rTrain = new double[trainSet.size()]; double[] pValid = new double[validSet.size()]; for (int i = 0; i < pTrain.length; i++) { Instance instance = trainSet.get(i); pTrain[i] = gam.regress(instance); } OptimUtils.computePseudoResidual(pTrain, target, rTrain); for (int i = 0; i < pValid.length; i++) { Instance instance = validSet.get(i); pValid[i] = gam.regress(instance); } // Gradient boosting // Resets the convergence tester ct.setMetric(metric); for (int iter = 0; iter < maxNumIters; iter++) { for (int j = 0; j < terms.size(); j++) { // Derivitive to attribute k // Minimizes the loss function: log(1 + exp(-yF)) for (int i = 0; i < trainSet.size(); i++) { trainSet.get(i).setTarget(rTrain[i]); } BoostedEnsemble boostedEnsemble = regressors.get(j); // Train model IntPair term = terms.get(j); cutter.setAttIndices(term.v1, term.v2); BaggedEnsemble baggedEnsemble = learner.build(bags); if (learningRate != 1) { for (int i = 0; i < baggedEnsemble.size(); i++) { Function2D func = (Function2D) baggedEnsemble.get(i); func.multiply(learningRate); } } boostedEnsemble.add(baggedEnsemble); // Update predictions for (int i 
= 0; i < trainSet.size(); i++) { Instance instance = trainSet.get(i); double pred = baggedEnsemble.regress(instance); pTrain[i] += pred; rTrain[i] = OptimUtils.getPseudoResidual(pTrain[i], target[i]); } for (int i = 0; i < validSet.size(); i++) { Instance instance = validSet.get(i); double pred = baggedEnsemble.regress(instance); pValid[i] += pred; } double measure = metric.eval(pValid, validSet); ct.add(measure); if (verbose) { System.out.println("Iteration " + iter + " term " + j + ": " + measure); } } if (ct.isConverged()) { break; } } // Search the best model on validation set int idx = ct.getBestIndex(); // Remove trees int n = idx / terms.size(); int m = idx % terms.size(); for (int k = 0; k < terms.size(); k++) { BoostedEnsemble boostedEnsemble = regressors.get(k); for (int i = boostedEnsemble.size(); i > n + 1; i--) { boostedEnsemble.removeLast(); } if (k > m) { boostedEnsemble.removeLast(); } } // Restore targets for (int i = 0; i < target.length; i++) { trainSet.get(i).setTarget(target[i]); } // Compress model List attributes = trainSet.getAttributes(); for (int i = 0; i < regressors.size(); i++) { BoostedEnsemble boostedEnsemble = regressors.get(i); IntPair term = terms.get(i); Attribute f1 = attributes.get(term.v1); Attribute f2 = attributes.get(term.v2); int n1 = -1; if (f1.getType() == Type.BINNED) { n1 = ((BinnedAttribute) f1).getNumBins(); } else if (f1.getType() == Type.NOMINAL) { n1 = ((NominalAttribute) f1).getCardinality(); } int n2 = -1; if (f2.getType() == Type.BINNED) { n2 = ((BinnedAttribute) f2).getNumBins(); } else if (f2.getType() == Type.NOMINAL) { n2 = ((NominalAttribute) f2).getCardinality(); } Array2D newRegressor = CompressionUtils.compress(term.v1, term.v2, n1, n2, boostedEnsemble); if (indices[i] < 0) { gam.add(new int[] { term.v1, term.v2 }, newRegressor); } else { Regressor regressor = gam.regressors.get(indices[i]); if (regressor instanceof Array2D) { Array2D ary = (Array2D) regressor; ary.add(newRegressor); } else { throw new 
RuntimeException("Failed to add new regressor"); } } } } /** * Builds a classifier. * * @param gam the GAM. * @param terms the list of feature interaction pairs. * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. */ public void buildClassifier(GAM gam, List terms, Instances trainSet, int maxNumIters) { SimpleMetric simpleMetric = (SimpleMetric) metric; List regressors = new ArrayList<>(); int[] indices = new int[terms.size()]; for (int i = 0; i < indices.length; i++) { indices[i] = indexOf(gam.terms, terms.get(i)); regressors.add(new BoostedEnsemble()); } // Backup targets double[] target = new double[trainSet.size()]; for (int i = 0; i < target.length; i++) { target[i] = trainSet.get(i).getTarget(); } // Create bags Instances[] bags = Sampling.createBags(trainSet, baggingIters); SquareCutter cutter = new SquareCutter(true); BaggedEnsembleLearner learner = new BaggedEnsembleLearner(bags.length, cutter); // Initialize predictions and residuals double[] pTrain = new double[trainSet.size()]; double[] rTrain = new double[trainSet.size()]; for (int i = 0; i < pTrain.length; i++) { Instance instance = trainSet.get(i); pTrain[i] = gam.regress(instance); } OptimUtils.computePseudoResidual(pTrain, target, rTrain); // Gradient boosting for (int iter = 0; iter < maxNumIters; iter++) { for (int j = 0; j < terms.size(); j++) { // Derivitive to attribute k // Minimizes the loss function: log(1 + exp(-yF)) for (int i = 0; i < trainSet.size(); i++) { trainSet.get(i).setTarget(rTrain[i]); } BoostedEnsemble boostedEnsemble = regressors.get(j); // Train model IntPair term = terms.get(j); cutter.setAttIndices(term.v1, term.v2); BaggedEnsemble baggedEnsemble = learner.build(bags); if (learningRate != 1) { for (int i = 0; i < baggedEnsemble.size(); i++) { Function2D func = (Function2D) baggedEnsemble.get(i); func.multiply(learningRate); } } boostedEnsemble.add(baggedEnsemble); // Update predictions for (int i = 0; i < trainSet.size(); i++) { 
Instance instance = trainSet.get(i); double pred = baggedEnsemble.regress(instance); pTrain[i] += pred; rTrain[i] = OptimUtils.getPseudoResidual(pTrain[i], target[i]); } double measure = simpleMetric.eval(pTrain, target); if (verbose) { System.out.println("Iteration " + iter + " term " + j + ": " + measure); } } } // Restore targets for (int i = 0; i < target.length; i++) { trainSet.get(i).setTarget(target[i]); } // Compress model List attributes = trainSet.getAttributes(); for (int i = 0; i < regressors.size(); i++) { BoostedEnsemble boostedEnsemble = regressors.get(i); IntPair term = terms.get(i); Attribute f1 = attributes.get(term.v1); Attribute f2 = attributes.get(term.v2); int n1 = -1; if (f1.getType() == Type.BINNED) { n1 = ((BinnedAttribute) f1).getNumBins(); } else if (f1.getType() == Type.NOMINAL) { n1 = ((NominalAttribute) f1).getCardinality(); } int n2 = -1; if (f2.getType() == Type.BINNED) { n2 = ((BinnedAttribute) f2).getNumBins(); } else if (f2.getType() == Type.NOMINAL) { n2 = ((NominalAttribute) f2).getCardinality(); } Array2D newRegressor = CompressionUtils.compress(term.v1, term.v2, n1, n2, boostedEnsemble); if (indices[i] < 0) { gam.add(new int[] { term.v1, term.v2 }, newRegressor); } else { Regressor regressor = gam.regressors.get(indices[i]); if (regressor instanceof Array2D) { Array2D ary = (Array2D) regressor; ary.add(newRegressor); } else { throw new RuntimeException("Failed to add new regressor"); } } } } /** * Builds a regressor. * * @param gam the GAM. * @param terms the list of feature interaction pairs. * @param trainSet the training set. * @param validSet the validation set. * @param maxNumIters the maximum number of iterations. 
*/ public void buildRegressor(GAM gam, List terms, Instances trainSet, Instances validSet, int maxNumIters) { List regressors = new ArrayList<>(); int[] indices = new int[terms.size()]; for (int i = 0; i < indices.length; i++) { indices[i] = indexOf(gam.terms, terms.get(i)); regressors.add(new BoostedEnsemble()); } // Backup targets double[] target = new double[trainSet.size()]; for (int i = 0; i < target.length; i++) { target[i] = trainSet.get(i).getTarget(); } // Create bags Instances[] bags = Sampling.createBags(trainSet, baggingIters); SquareCutter cutter = new SquareCutter(); BaggedEnsembleLearner learner = new BaggedEnsembleLearner(baggingIters, cutter); // Initialize predictions and residuals double[] rTrain = new double[trainSet.size()]; double[] pValid = new double[validSet.size()]; double[] rValid = new double[validSet.size()]; for (int i = 0; i < trainSet.size(); i++) { Instance instance = trainSet.get(i); rTrain[i] = instance.getTarget() - gam.regress(instance); } for (int i = 0; i < validSet.size(); i++) { Instance instance = validSet.get(i); pValid[i] = gam.regress(instance); rValid[i] = instance.getTarget() - pValid[i]; } // Gradient boosting // Resets the convergence tester ct.setMetric(metric); for (int iter = 0; iter < maxNumIters; iter++) { for (int j = 0; j < terms.size(); j++) { // Derivative to attribute k // Equivalent to residual BoostedEnsemble boostedEnsemble = regressors.get(j); // Prepare training set for (int i = 0; i < rTrain.length; i++) { trainSet.get(i).setTarget(rTrain[i]); } // Train model IntPair term = terms.get(j); cutter.setAttIndices(term.v1, term.v2); BaggedEnsemble baggedEnsemble = learner.build(bags); if (learningRate != 1) { for (int i = 0; i < baggedEnsemble.size(); i++) { Function2D func = (Function2D) baggedEnsemble.get(i); func.multiply(learningRate); } } boostedEnsemble.add(baggedEnsemble); // Update residuals for (int i = 0; i < rTrain.length; i++) { Instance instance = trainSet.get(i); double pred = 
baggedEnsemble.regress(instance); rTrain[i] -= pred; } for (int i = 0; i < rValid.length; i++) { Instance instance = validSet.get(i); double pred = baggedEnsemble.regress(instance); pValid[i] += pred; rValid[i] -= pred; } double measure = metric.eval(pValid, validSet); ct.add(measure); if (verbose) { System.out.println("Iteration " + iter + " term " + j + ":" + measure); } } if (ct.isConverged()) { break; } } // Search the best model on validation set int idx = ct.getBestIndex(); // Remove trees int n = idx / terms.size(); int m = idx % terms.size(); for (int k = 0; k < terms.size(); k++) { BoostedEnsemble boostedEnsemble = regressors.get(k); for (int i = boostedEnsemble.size(); i > n + 1; i--) { boostedEnsemble.removeLast(); } if (k > m) { boostedEnsemble.removeLast(); } } // Restore targets for (int i = 0; i < target.length; i++) { trainSet.get(i).setTarget(target[i]); } // Compress model List attributes = trainSet.getAttributes(); for (int i = 0; i < regressors.size(); i++) { BoostedEnsemble boostedEnsemble = regressors.get(i); IntPair term = terms.get(i); Attribute f1 = attributes.get(term.v1); Attribute f2 = attributes.get(term.v2); int n1 = -1; if (f1.getType() == Type.BINNED) { n1 = ((BinnedAttribute) f1).getNumBins(); } else if (f1.getType() == Type.NOMINAL) { n1 = ((NominalAttribute) f1).getCardinality(); } int n2 = -1; if (f2.getType() == Type.BINNED) { n2 = ((BinnedAttribute) f2).getNumBins(); } else if (f2.getType() == Type.NOMINAL) { n2 = ((NominalAttribute) f2).getCardinality(); } Array2D newRegressor = CompressionUtils.compress(term.v1, term.v2, n1, n2, boostedEnsemble); if (indices[i] < 0) { gam.add(new int[] { term.v1, term.v2 }, newRegressor); } else { Regressor regressor = gam.regressors.get(indices[i]); if (regressor instanceof Array2D) { Array2D ary = (Array2D) regressor; ary.add(newRegressor); } else { throw new RuntimeException("Failed to add new regressor"); } } } } /** * Builds a regressor. * * @param gam the GAM. 
* @param terms the list of feature interaction pairs. * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. */ public void buildRegressor(GAM gam, List terms, Instances trainSet, int maxNumIters) { SimpleMetric simpleMetric = (SimpleMetric) metric; List regressors = new ArrayList<>(); int[] indices = new int[terms.size()]; for (int i = 0; i < indices.length; i++) { indices[i] = indexOf(gam.terms, terms.get(i)); regressors.add(new BoostedEnsemble()); } // Backup targets double[] target = new double[trainSet.size()]; for (int i = 0; i < target.length; i++) { target[i] = trainSet.get(i).getTarget(); } // Create bags Instances[] bags = Sampling.createBags(trainSet, baggingIters); SquareCutter cutter = new SquareCutter(); BaggedEnsembleLearner learner = new BaggedEnsembleLearner(baggingIters, cutter); // Initialize predictions and residuals double[] pTrain = new double[trainSet.size()]; double[] rTrain = new double[trainSet.size()]; for (int i = 0; i < trainSet.size(); i++) { Instance instance = trainSet.get(i); pTrain[i] = gam.regress(instance); rTrain[i] = instance.getTarget() - pTrain[i]; } // Gradient boosting for (int iter = 0; iter < maxNumIters; iter++) { for (int j = 0; j < terms.size(); j++) { // Derivative to attribute k // Equivalent to residual BoostedEnsemble boostedEnsemble = regressors.get(j); // Prepare training set for (int i = 0; i < rTrain.length; i++) { trainSet.get(i).setTarget(rTrain[i]); } // Train model IntPair term = terms.get(j); cutter.setAttIndices(term.v1, term.v2); BaggedEnsemble baggedEnsemble = learner.build(bags); if (learningRate != 1) { for (int i = 0; i < baggedEnsemble.size(); i++) { Function2D func = (Function2D) baggedEnsemble.get(i); func.multiply(learningRate); } } boostedEnsemble.add(baggedEnsemble); // Update residuals for (int i = 0; i < rTrain.length; i++) { Instance instance = trainSet.get(i); double pred = baggedEnsemble.regress(instance); rTrain[i] -= pred; } double measure = 
simpleMetric.eval(pTrain, target); if (verbose) { System.out.println("Iteration " + iter + " term " + j + ":" + measure); } } } // Restore targets for (int i = 0; i < target.length; i++) { trainSet.get(i).setTarget(target[i]); } // Compress model List attributes = trainSet.getAttributes(); for (int i = 0; i < regressors.size(); i++) { BoostedEnsemble boostedEnsemble = regressors.get(i); IntPair term = terms.get(i); Attribute f1 = attributes.get(term.v1); Attribute f2 = attributes.get(term.v2); int n1 = -1; if (f1.getType() == Type.BINNED) { n1 = ((BinnedAttribute) f1).getNumBins(); } else if (f1.getType() == Type.NOMINAL) { n1 = ((NominalAttribute) f1).getCardinality(); } int n2 = -1; if (f2.getType() == Type.BINNED) { n2 = ((BinnedAttribute) f2).getNumBins(); } else if (f2.getType() == Type.NOMINAL) { n2 = ((NominalAttribute) f2).getCardinality(); } Array2D newRegressor = CompressionUtils.compress(term.v1, term.v2, n1, n2, boostedEnsemble); if (indices[i] < 0) { gam.add(new int[] { term.v1, term.v2 }, newRegressor); } else { Regressor regressor = gam.regressors.get(indices[i]); if (regressor instanceof Array2D) { Array2D ary = (Array2D) regressor; ary.add(newRegressor); } else { throw new RuntimeException("Failed to add new regressor"); } } } } @Override public GAM build(Instances instances) { if (pairs == null) { int p = instances.dimension(); pairs = new ArrayList(); for (int i = 0; i < p; i++) { for (int j = i + 1; j < p; j++) { pairs.add(new IntPair(i, j)); } } } if (maxNumIters < 0) { maxNumIters = 20; } switch (task) { case REGRESSION: if (validSet != null) { buildRegressor(gam, pairs, instances, validSet, maxNumIters); } else { buildRegressor(gam, pairs, instances, maxNumIters); } break; case CLASSIFICATION: if (validSet != null) { buildClassifier(gam, pairs, instances, validSet, maxNumIters); } else { buildClassifier(gam, pairs, instances, maxNumIters); } break; default: break; } return gam; } private int indexOf(List terms, IntPair pair) { for (int i = 0; 
i < terms.size(); i++) { int[] term = terms.get(i); if (term.length == 2 && term[0] == pair.v1 && term[1] == pair.v2) { return i; } } return -1; } } ================================================ FILE: src/main/java/mltk/predictor/gam/GAM.java ================================================ package mltk.predictor.gam; import java.io.BufferedReader; import java.io.PrintWriter; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; import mltk.core.Instance; import mltk.predictor.ProbabilisticClassifier; import mltk.predictor.Regressor; import mltk.util.ArrayUtils; import mltk.util.MathUtils; /** * Class for generalized additive models (GAMs). * * @author Yin Lou * */ public class GAM implements ProbabilisticClassifier, Regressor { class RegressorList implements Iterable { List regressors; RegressorList() { regressors = new ArrayList<>(); } @Override public Iterator iterator() { return regressors.iterator(); } } class TermList implements Iterable { List terms; TermList() { terms = new ArrayList<>(); } @Override public Iterator iterator() { return terms.iterator(); } } protected double intercept; protected List regressors; protected List terms; /** * Constructor. */ public GAM() { regressors = new ArrayList<>(); terms = new ArrayList<>(); intercept = 0; } /** * Returns the intercept. * * @return the intercept. */ public double getIntercept() { return intercept; } /** * Sets the intercept. * * @param intercept the new intercept. 
*/ public void setIntercept(double intercept) { this.intercept = intercept; } @Override public void read(BufferedReader in) throws Exception { intercept = Double.parseDouble(in.readLine().split(": ")[1]); int size = Integer.parseInt(in.readLine().split(": ")[1]); regressors = new ArrayList<>(size); terms = new ArrayList<>(size); in.readLine(); for (int i = 0; i < size; i++) { int[] term = ArrayUtils.parseIntArray(in.readLine().split(": ")[1]); terms.add(term); in.readLine(); String line = in.readLine(); String regressorName = line.substring(1, line.length() - 1).split(": ")[1]; Class clazz = Class.forName(regressorName); Regressor regressor = (Regressor) clazz.getDeclaredConstructor().newInstance(); regressor.read(in); regressors.add(regressor); in.readLine(); } } @Override public void write(PrintWriter out) throws Exception { out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName()); out.println("Intercept: " + intercept); out.println("Components: " + regressors.size()); out.println(); for (int i = 0; i < regressors.size(); i++) { out.println("Component: " + Arrays.toString(terms.get(i))); out.println(); regressors.get(i).write(out); out.println(); } } /** * Adds a new term into this GAM. The term is an array of attribute indices that are used in the regressor. * * @param term the new term to add. * @param regressor the new regressor to add. */ public void add(int[] term, Regressor regressor) { terms.add(term); regressors.add(regressor); } @Override public double regress(Instance instance) { double pred = intercept; for (Regressor regressor : regressors) { pred += regressor.regress(instance); } return pred; } @Override public int classify(Instance instance) { double pred = regress(instance); return pred >= 0 ? 1 : 0; } @Override public double[] predictProbabilities(Instance instance) { double pred = regress(instance); double prob = MathUtils.sigmoid(pred); return new double[] { 1 - prob, prob }; } /** * Returns the term list. * * @return the term list. 
*/ public List getTerms() { return terms; } /** * Returns the regressor list. * * @return the regressor list. */ public List getRegressors() { return regressors; } @Override public GAM copy() { GAM copy = new GAM(); copy.intercept = intercept; for (Regressor regressor : regressors) { copy.regressors.add((Regressor) regressor.copy()); } for (int[] term : terms) { copy.terms.add(term.clone()); } return copy; } } ================================================ FILE: src/main/java/mltk/predictor/gam/GAMLearner.java ================================================ package mltk.predictor.gam; import java.util.ArrayList; import java.util.List; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.cmdline.options.HoldoutValidatedLearnerWithTaskOptions; import mltk.core.Attribute; import mltk.core.BinnedAttribute; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.core.Attribute.Type; import mltk.core.io.InstancesReader; import mltk.predictor.BaggedEnsemble; import mltk.predictor.BoostedEnsemble; import mltk.predictor.HoldoutValidatedLearner; import mltk.predictor.Regressor; import mltk.predictor.evaluation.ConvergenceTester; import mltk.predictor.evaluation.Metric; import mltk.predictor.evaluation.MetricFactory; import mltk.predictor.evaluation.SimpleMetric; import mltk.predictor.function.BaggedLineCutter; import mltk.predictor.function.CompressionUtils; import mltk.predictor.function.EnsembledLineCutter; import mltk.predictor.function.Function1D; import mltk.predictor.function.SubaggedLineCutter; import mltk.predictor.io.PredictorWriter; import mltk.util.OptimUtils; import mltk.util.Random; /** * Class for learning GAMs via gradient tree boosting. * *

* Reference:
* Y. Lou, R. Caruana and J. Gehrke. Intelligible models for classification and regression. In Proceedings of the * 18th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), Beijing, China, 2012.
* * Y. Lou, Y. Wang, S. Liang and Y. Dong. Efficiently Training Intelligible Models for Global Explanations. * In Proceedings of the 29th ACM International Conference on Information and Knowledge Management (CIKM), * Virtual Event, Ireland, 2020. *

* * @author Yin Lou * */ public class GAMLearner extends HoldoutValidatedLearner { static class Options extends HoldoutValidatedLearnerWithTaskOptions { @Argument(name = "-b", description = "base learner (default: tr:3:100:0.65)") String baseLearner = "tr:3:100:0.65"; @Argument(name = "-m", description = "maximum number of iterations", required = true) int maxNumIters = -1; @Argument(name = "-s", description = "seed of the random number generator (default: 0)") long seed = 0L; @Argument(name = "-l", description = "learning rate (default: 0.01)") double learningRate = 0.01; } /** * Trains a GAM. * *
	 * Usage: mltk.predictor.gam.GAMLearner
	 * -t	train set path
	 * -m	maximum number of iterations
	 * [-g]	task between classification (c) and regression (r) (default: r)
	 * [-v]	valid set path
	 * [-e]	evaluation metric (default: default metric of task)
	 * [-S]	convergence criteria (default: -1)
	 * [-r]	attribute file path
	 * [-o]	output model path
	 * [-V]	verbose (default: true)
	 * [-b]	base learner (default: tr:3:100:0.65)
	 * [-s]	seed of the random number generator (default: 0)
	 * [-l]	learning rate (default: 0.01)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(GAMLearner.class, opts); Task task = null; Metric metric = null; try { parser.parse(args); task = Task.get(opts.task); if (opts.metric == null) { metric = task.getDefaultMetric(); } else { metric = MetricFactory.getMetric(opts.metric); } } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Random.getInstance().setSeed(opts.seed); ConvergenceTester ct = ConvergenceTester.parse(opts.cc); Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath); GAMLearner learner = new GAMLearner(); learner.setBaseLearner(opts.baseLearner); learner.setMaxNumIters(opts.maxNumIters); learner.setLearningRate(opts.learningRate); learner.setTask(task); learner.setMetric(metric); learner.setConvergenceTester(ct); learner.setVerbose(opts.verbose); if (opts.validPath != null) { Instances validSet = InstancesReader.read(opts.attPath, opts.validPath); learner.setValidSet(validSet); } long start = System.currentTimeMillis(); GAM gam = learner.build(trainSet); long end = System.currentTimeMillis(); System.out.println("Time: " + (end - start) / 1000.0); if (opts.outputModelPath != null) { PredictorWriter.write(gam, opts.outputModelPath); } } private int baggingIters; private int maxNumIters; private int maxNumLeaves; private Task task; private double alpha; private double learningRate; /** * Constructor. */ public GAMLearner() { verbose = false; baggingIters = 100; maxNumIters = -1; maxNumLeaves = 3; alpha = 0.65; learningRate = 0.01; task = Task.REGRESSION; metric = task.getDefaultMetric(); } /** * Returns the number of bagging iterations. * * @return the number of bagging iterations. */ public int getBaggingIters() { return baggingIters; } /** * Sets the number of bagging iterations. * * @param baggingIters the number of bagging iterations. 
*/
	public void setBaggingIters(int baggingIters) {
		this.baggingIters = baggingIters;
	}

	/**
	 * Returns the maximum number of iterations.
	 *
	 * @return the maximum number of iterations.
	 */
	public int getMaxNumIters() {
		return maxNumIters;
	}

	/**
	 * Sets the maximum number of iterations. If the value is still negative when {@code build} is called,
	 * {@code build} replaces it with 20.
	 *
	 * @param maxNumIters the maximum number of iterations.
	 */
	public void setMaxNumIters(int maxNumIters) {
		this.maxNumIters = maxNumIters;
	}

	/**
	 * Returns the maximum number of leaves.
	 *
	 * @return the maximum number of leaves.
	 */
	public int getMaxNumLeaves() {
		return maxNumLeaves;
	}

	/**
	 * Sets the maximum number of leaves. {@code build} passes this value to the build methods, which hand it to
	 * the line cutter as its number of intervals.
	 *
	 * @param maxNumLeaves the maximum number of leaves.
	 */
	public void setMaxNumLeaves(int maxNumLeaves) {
		this.maxNumLeaves = maxNumLeaves;
	}

	/**
	 * Returns the learning rate.
	 *
	 * @return the learning rate.
	 */
	public double getLearningRate() {
		return learningRate;
	}

	/**
	 * Sets the learning rate. Every function fitted during boosting is multiplied by this rate unless the rate
	 * is exactly 1.
	 *
	 * @param learningRate the learning rate.
	 */
	public void setLearningRate(double learningRate) {
		this.learningRate = learningRate;
	}

	/**
	 * Returns the subsampling ratio.
	 *
	 * @return the subsampling ratio.
	 */
	public double getSubsamplingRatio() {
		return alpha;
	}

	/**
	 * Sets the subsampling ratio. The build methods use a {@code SubaggedLineCutter} with this ratio when it lies
	 * in [0, 1], and a {@code BaggedLineCutter} for any other value (for example -1).
	 *
	 * @param alpha the subsampling ratio.
	 */
	public void setSubsamplingRatio(double alpha) {
		this.alpha = alpha;
	}

	/**
	 * Returns the task of this learner.
	 *
	 * @return the task of this learner.
	 */
	public Task getTask() {
		return task;
	}

	/**
	 * Sets the task of this learner.
	 *
	 * @param task the task of this learner.
	 */
	public void setTask(Task task) {
		this.task = task;
	}

	/**
	 * Sets the base learner.
* * @param option the option string. */ public void setBaseLearner(String option) { String[] opts = option.split(":"); switch (opts[0]) { case "tr": int maxNumLeaves = Integer.parseInt(opts[1]); int baggingIters = Integer.parseInt(opts[2]); setMaxNumLeaves(maxNumLeaves); setBaggingIters(baggingIters); if (opts.length > 3) { double alpha = Double.parseDouble(opts[3]); setSubsamplingRatio(alpha); } else { setSubsamplingRatio(-1); } break; case "cs": break; default: break; } } /** * Builds a classifier. * * @param trainSet the training set. * @param validSet the validation set. * @param maxNumIters the maximum number of iterations. * @param maxNumLeaves the maximum number of leaves. * @return a classifier. */ public GAM buildClassifier(Instances trainSet, Instances validSet, int maxNumIters, int maxNumLeaves) { GAM gam = new GAM(); // Backup targets and weights double[] target = new double[trainSet.size()]; double[] weight = new double[trainSet.size()]; for (int i = 0; i < target.length; i++) { Instance instance = trainSet.get(i); target[i] = instance.getTarget(); weight[i] = instance.getWeight(); } List attributes = trainSet.getAttributes(); List regressors = new ArrayList<>(attributes.size()); for (int i = 0; i < attributes.size(); i++) { regressors.add(new BoostedEnsemble()); } // Create bags EnsembledLineCutter elc = null; if (0 <= alpha & alpha <= 1) { SubaggedLineCutter slc = new SubaggedLineCutter(true); slc.createSubags(trainSet.size(), alpha, baggingIters); elc = slc; } else { BaggedLineCutter blc = new BaggedLineCutter(true); blc.createBags(trainSet.size(), baggingIters); elc = blc; } elc.setNumIntervals(maxNumLeaves); // Initialize predictions and residuals double[] predTrain = new double[trainSet.size()]; double[] probTrain = new double[trainSet.size()]; double[] rTrain = new double[trainSet.size()]; OptimUtils.computeProbabilities(predTrain, probTrain); OptimUtils.computePseudoResidual(predTrain, target, rTrain); double[] pValid = new 
double[validSet.size()];

		// Gradient boosting
		// Resets the convergence tester
		ct.setMetric(metric);
		for (int iter = 0; iter < maxNumIters; iter++) {
			for (int j = 0; j < attributes.size(); j++) {
				// Derivative with respect to attribute j
				// Minimizes the loss function: log(1 + exp(-yF))
				// The target handed to the line cutter is already multiplied by the original instance weight,
				// and the weight is p * (1 - p) times the original weight
				for (int i = 0; i < trainSet.size(); i++) {
					Instance instance = trainSet.get(i);
					double prob = probTrain[i];
					double w = prob * (1 - prob);
					instance.setTarget(rTrain[i] * weight[i]);
					instance.setWeight(w * weight[i]);
				}

				BoostedEnsemble boostedEnsemble = regressors.get(j);

				// Train model
				elc.setAttributeIndex(j);
				BaggedEnsemble baggedEnsemble = elc.build(trainSet);
				Function1D func = CompressionUtils.compress(attributes.get(j).getIndex(), baggedEnsemble);
				if (learningRate != 1) {
					func.multiply(learningRate);
				}
				boostedEnsemble.add(func);
				baggedEnsemble = null;

				// Update predictions
				for (int i = 0; i < trainSet.size(); i++) {
					Instance instance = trainSet.get(i);
					double pred = func.regress(instance);
					predTrain[i] += pred;
				}
				OptimUtils.computeProbabilities(predTrain, probTrain);
				OptimUtils.computePseudoResidual(predTrain, target, rTrain);
				for (int i = 0; i < validSet.size(); i++) {
					Instance instance = validSet.get(i);
					double pred = func.regress(instance);
					pValid[i] += pred;
				}

				// One measure is recorded per (iteration, attribute) step, in that order
				double measure = metric.eval(pValid, validSet);
				ct.add(measure);
				if (verbose) {
					System.out.println("Iteration " + iter + " Feature " + j + ": " + measure);
				}
			}
			// Convergence is only checked after a full pass, so all ensembles have the same size here
			if (ct.isConverged()) {
				break;
			}
		}

		// Search the best model on validation set
		int idx = ct.getBestIndex();

		// Remove trees
		// idx counts (iteration, attribute) steps: n is the iteration of the best step and m its attribute.
		// Attributes up to and including m keep n + 1 functions; later attributes keep n.
		// NOTE(review): assumes at least one measure was recorded; confirm what getBestIndex returns otherwise.
		int n = idx / attributes.size();
		int m = idx % attributes.size();
		for (int k = 0; k < regressors.size(); k++) {
			BoostedEnsemble boostedEnsemble = regressors.get(k);
			for (int i = boostedEnsemble.size(); i > n + 1; i--) {
				boostedEnsemble.removeLast();
			}
			if (k > m) {
				boostedEnsemble.removeLast();
			}
		}

		// Restore targets and weights
		for (int i = 0; i < target.length; i++) {
			trainSet.get(i).setTarget(target[i]);
			trainSet.get(i).setWeight(weight[i]);
		}

		//
Compress model for (int i = 0; i < regressors.size(); i++) { BoostedEnsemble boostedEnsemble = regressors.get(i); Attribute attribute = attributes.get(i); int attIndex = attribute.getIndex(); Function1D function = CompressionUtils.compress(attIndex, boostedEnsemble); Regressor regressor = function; if (attribute.getType() == Type.BINNED) { int l = ((BinnedAttribute) attribute).getNumBins(); regressor = CompressionUtils.convert(l, function); } else if (attribute.getType() == Type.NOMINAL) { int l = ((NominalAttribute) attribute).getCardinality(); regressor = CompressionUtils.convert(l, function); } gam.add(new int[] { attIndex }, regressor); } return gam; } /** * Builds a classifier. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @param maxNumLeaves the maximum number of leaves. * @return a classifier. */ public GAM buildClassifier(Instances trainSet, int maxNumIters, int maxNumLeaves) { GAM gam = new GAM(); SimpleMetric simpleMetric = (SimpleMetric) metric; // Backup targets and weights double[] target = new double[trainSet.size()]; double[] weight = new double[trainSet.size()]; for (int i = 0; i < target.length; i++) { Instance instance = trainSet.get(i); target[i] = instance.getTarget(); weight[i] = instance.getWeight(); } List attributes = trainSet.getAttributes(); List regressors = new ArrayList<>(attributes.size()); for (int i = 0; i < attributes.size(); i++) { regressors.add(new BoostedEnsemble()); } // Create bags EnsembledLineCutter elc = null; if (0 <= alpha & alpha <= 1) { SubaggedLineCutter slc = new SubaggedLineCutter(true); slc.createSubags(trainSet.size(), alpha, baggingIters); elc = slc; } else { BaggedLineCutter blc = new BaggedLineCutter(true); blc.createBags(trainSet.size(), baggingIters); elc = blc; } elc.setNumIntervals(maxNumLeaves); // Initialize predictions and residuals double[] predTrain = new double[trainSet.size()]; double[] probTrain = new double[trainSet.size()]; double[] rTrain = new 
double[trainSet.size()];
		OptimUtils.computeProbabilities(predTrain, probTrain);
		OptimUtils.computePseudoResidual(predTrain, target, rTrain);

		// Gradient boosting
		for (int iter = 0; iter < maxNumIters; iter++) {
			for (int j = 0; j < attributes.size(); j++) {
				// Derivative with respect to attribute j
				// Minimizes the loss function: log(1 + exp(-yF))
				// The target handed to the line cutter is already multiplied by the original instance weight,
				// and the weight is p * (1 - p) times the original weight
				for (int i = 0; i < trainSet.size(); i++) {
					Instance instance = trainSet.get(i);
					double prob = probTrain[i];
					double w = prob * (1 - prob);
					instance.setTarget(rTrain[i] * weight[i]);
					instance.setWeight(w * weight[i]);
				}

				BoostedEnsemble boostedEnsemble = regressors.get(j);

				// Train model
				elc.setAttributeIndex(j);
				BaggedEnsemble baggedEnsemble = elc.build(trainSet);
				Function1D func = CompressionUtils.compress(attributes.get(j).getIndex(), baggedEnsemble);
				if (learningRate != 1) {
					func.multiply(learningRate);
				}
				boostedEnsemble.add(func);
				baggedEnsemble = null;

				// Update predictions
				for (int i = 0; i < trainSet.size(); i++) {
					Instance instance = trainSet.get(i);
					double pred = func.regress(instance);
					predTrain[i] += pred;
				}
				OptimUtils.computeProbabilities(predTrain, probTrain);
				OptimUtils.computePseudoResidual(predTrain, target, rTrain);

				// Training-set measure; it is only reported, never used to stop or prune
				double measure = simpleMetric.eval(predTrain, target);
				if (verbose) {
					System.out.println("Iteration " + iter + " Feature " + j + ": " + measure);
				}
			}
		}

		// Restore targets and weights
		for (int i = 0; i < target.length; i++) {
			trainSet.get(i).setTarget(target[i]);
			trainSet.get(i).setWeight(weight[i]);
		}

		// Compress model
		// Each boosted ensemble is collapsed into one function; binned and nominal attributes are further
		// converted using their number of bins or cardinality
		for (int i = 0; i < regressors.size(); i++) {
			BoostedEnsemble boostedEnsemble = regressors.get(i);
			Attribute attribute = attributes.get(i);
			int attIndex = attribute.getIndex();
			Function1D function = CompressionUtils.compress(attIndex, boostedEnsemble);
			Regressor regressor = function;
			if (attribute.getType() == Type.BINNED) {
				int l = ((BinnedAttribute) attribute).getNumBins();
				regressor = CompressionUtils.convert(l, function);
			} else if (attribute.getType() == Type.NOMINAL) {
				int l =
((NominalAttribute) attribute).getCardinality(); regressor = CompressionUtils.convert(l, function); } gam.add(new int[] { attIndex }, regressor); } return gam; } /** * Builds a regressor. * * @param trainSet the training set. * @param validSet the validation set. * @param maxNumIters the maximum number of iterations. * @param maxNumLeaves the maximum number of leaves. * @return a regressor. */ public GAM buildRegressor(Instances trainSet, Instances validSet, int maxNumIters, int maxNumLeaves) { GAM gam = new GAM(); List attributes = trainSet.getAttributes(); List regressors = new ArrayList<>(attributes.size()); for (int i = 0; i < attributes.size(); i++) { regressors.add(new BoostedEnsemble()); } // Backup targets double[] target = new double[trainSet.size()]; for (int i = 0; i < target.length; i++) { target[i] = trainSet.get(i).getTarget(); } // Create bags EnsembledLineCutter elc = null; if (0 <= alpha & alpha <= 1) { SubaggedLineCutter slc = new SubaggedLineCutter(false); slc.createSubags(trainSet.size(), alpha, baggingIters); elc = slc; } else { BaggedLineCutter blc = new BaggedLineCutter(false); blc.createBags(trainSet.size(), baggingIters); elc = blc; } elc.setNumIntervals(maxNumLeaves); // Initialize predictions and residuals double[] rTrain = new double[trainSet.size()]; double[] pValid = new double[validSet.size()]; double[] rValid = new double[validSet.size()]; for (int i = 0; i < trainSet.size(); i++) { Instance instance = trainSet.get(i); rTrain[i] = instance.getTarget(); } for (int i = 0; i < validSet.size(); i++) { Instance instance = validSet.get(i); rValid[i] = instance.getTarget(); } // Gradient boosting // Resets the convergence tester ct.setMetric(metric); for (int iter = 0; iter < maxNumIters; iter++) { for (int j = 0; j < attributes.size(); j++) { // Derivative to attribute k // Equivalent to residual BoostedEnsemble boostedEnsemble = regressors.get(j); // Prepare training set for (int i = 0; i < rTrain.length; i++) { 
trainSet.get(i).setTarget(rTrain[i]);
				}

				// Train model
				elc.setAttributeIndex(j);
				BaggedEnsemble baggedEnsemble = elc.build(trainSet);
				Function1D func = CompressionUtils.compress(attributes.get(j).getIndex(), baggedEnsemble);
				if (learningRate != 1) {
					func.multiply(learningRate);
				}
				boostedEnsemble.add(func);
				baggedEnsemble = null;

				// Update residuals
				for (int i = 0; i < rTrain.length; i++) {
					Instance instance = trainSet.get(i);
					double pred = func.regress(instance);
					rTrain[i] -= pred;
				}
				for (int i = 0; i < rValid.length; i++) {
					Instance instance = validSet.get(i);
					double pred = func.regress(instance);
					pValid[i] += pred;
					rValid[i] -= pred;
				}

				// One measure is recorded per (iteration, attribute) step, in that order
				double measure = metric.eval(pValid, validSet);
				ct.add(measure);
				if (verbose) {
					System.out.println("Iteration " + iter + " Feature " + j + ": " + measure);
				}
			}
			// Convergence is only checked after a full pass, so all ensembles have the same size here
			if (ct.isConverged()) {
				break;
			}
		}

		// Search the best model on validation set
		int idx = ct.getBestIndex();

		// Prune tree ensembles
		// idx counts (iteration, attribute) steps: n is the iteration of the best step and m its attribute.
		// Attributes up to and including m keep n + 1 functions; later attributes keep n.
		// NOTE(review): assumes at least one measure was recorded; confirm what getBestIndex returns otherwise.
		int n = idx / attributes.size();
		int m = idx % attributes.size();
		for (int k = 0; k < regressors.size(); k++) {
			BoostedEnsemble boostedEnsemble = regressors.get(k);
			for (int i = boostedEnsemble.size(); i > n + 1; i--) {
				boostedEnsemble.removeLast();
			}
			if (k > m) {
				boostedEnsemble.removeLast();
			}
		}

		// Restore targets
		for (int i = 0; i < target.length; i++) {
			trainSet.get(i).setTarget(target[i]);
		}

		// Compress model
		// Each boosted ensemble is collapsed into one function; binned and nominal attributes are further
		// converted using their number of bins or cardinality
		for (int i = 0; i < regressors.size(); i++) {
			BoostedEnsemble boostedEnsemble = regressors.get(i);
			Attribute attribute = attributes.get(i);
			int attIndex = attribute.getIndex();
			Function1D function = CompressionUtils.compress(attIndex, boostedEnsemble);
			Regressor regressor = function;
			if (attribute.getType() == Type.BINNED) {
				int l = ((BinnedAttribute) attribute).getNumBins();
				regressor = CompressionUtils.convert(l, function);
			} else if (attribute.getType() == Type.NOMINAL) {
				int l = ((NominalAttribute) attribute).getCardinality();
				regressor = CompressionUtils.convert(l, function);
			}
			gam.add(new int[] { attIndex }, regressor);
		}

		return
gam; } /** * Builds a regressor. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @param maxNumLeaves the maximum number of leaves. * @return a regressor. */ public GAM buildRegressor(Instances trainSet, int maxNumIters, int maxNumLeaves) { GAM gam = new GAM(); SimpleMetric simpleMetric = (SimpleMetric) metric; List attributes = trainSet.getAttributes(); List regressors = new ArrayList<>(attributes.size()); for (int i = 0; i < attributes.size(); i++) { regressors.add(new BoostedEnsemble()); } // Backup targets double[] target = new double[trainSet.size()]; for (int i = 0; i < target.length; i++) { target[i] = trainSet.get(i).getTarget(); } // Create bags EnsembledLineCutter elc = null; if (0 <= alpha & alpha <= 1) { SubaggedLineCutter slc = new SubaggedLineCutter(true); slc.createSubags(trainSet.size(), alpha, baggingIters); elc = slc; } else { BaggedLineCutter blc = new BaggedLineCutter(true); blc.createBags(trainSet.size(), baggingIters); elc = blc; } elc.setNumIntervals(maxNumLeaves); // Initialize predictions and residuals double[] pTrain = new double[trainSet.size()]; double[] rTrain = new double[trainSet.size()]; for (int i = 0; i < trainSet.size(); i++) { Instance instance = trainSet.get(i); rTrain[i] = instance.getTarget(); } // Gradient boosting for (int iter = 0; iter < maxNumIters; iter++) { for (int j = 0; j < attributes.size(); j++) { // Derivative to attribute k // Equivalent to residual BoostedEnsemble boostedEnsemble = regressors.get(j); // Prepare training set for (int i = 0; i < rTrain.length; i++) { trainSet.get(i).setTarget(rTrain[i]); } // Train model elc.setAttributeIndex(j); BaggedEnsemble baggedEnsemble = elc.build(trainSet); Function1D func = CompressionUtils.compress(attributes.get(j).getIndex(), baggedEnsemble); if (learningRate != 1) { func.multiply(learningRate); } boostedEnsemble.add(func); baggedEnsemble = null; // Update residuals for (int i = 0; i < rTrain.length; i++) { Instance 
instance = trainSet.get(i);
					double pred = func.regress(instance);
					pTrain[i] += pred;
					rTrain[i] -= pred;
				}

				// Training-set measure; it is only reported, never used to stop or prune
				double measure = simpleMetric.eval(pTrain, target);
				if (verbose) {
					System.out.println("Iteration " + iter + " Feature " + j + ": " + measure);
				}
			}
		}

		// Restore targets
		for (int i = 0; i < target.length; i++) {
			trainSet.get(i).setTarget(target[i]);
		}

		// Compress model
		// Each boosted ensemble is collapsed into one function; binned and nominal attributes are further
		// converted using their number of bins or cardinality
		for (int i = 0; i < regressors.size(); i++) {
			BoostedEnsemble boostedEnsemble = regressors.get(i);
			Attribute attribute = attributes.get(i);
			int attIndex = attribute.getIndex();
			Function1D function = CompressionUtils.compress(attIndex, boostedEnsemble);
			Regressor regressor = function;
			if (attribute.getType() == Type.BINNED) {
				int l = ((BinnedAttribute) attribute).getNumBins();
				regressor = CompressionUtils.convert(l, function);
			} else if (attribute.getType() == Type.NOMINAL) {
				int l = ((NominalAttribute) attribute).getCardinality();
				regressor = CompressionUtils.convert(l, function);
			}
			gam.add(new int[] { attIndex }, regressor);
		}

		return gam;
	}

	/**
	 * Builds a GAM for the configured task. A negative maximum number of iterations is replaced with 20 and a
	 * missing metric with the default metric of the task; both replacements are stored in this learner. The
	 * validated build method is used when a validation set has been set, the unvalidated one otherwise.
	 *
	 * @param instances the training set.
	 * @return the GAM, or {@code null} if the task is neither regression nor classification.
	 */
	@Override
	public GAM build(Instances instances) {
		GAM gam = null;
		if (maxNumIters < 0) {
			maxNumIters = 20;
		}
		if (metric == null) {
			metric = task.getDefaultMetric();
		}
		switch (task) {
			case REGRESSION:
				if (validSet != null) {
					gam = buildRegressor(instances, validSet, maxNumIters, maxNumLeaves);
				} else {
					gam = buildRegressor(instances, maxNumIters, maxNumLeaves);
				}
				break;
			case CLASSIFICATION:
				if (validSet != null) {
					gam = buildClassifier(instances, validSet, maxNumIters, maxNumLeaves);
				} else {
					gam = buildClassifier(instances, maxNumIters, maxNumLeaves);
				}
				break;
			default:
				break;
		}
		return gam;
	}

}

================================================
FILE: src/main/java/mltk/predictor/gam/GAMUtils.java
================================================
package mltk.predictor.gam;

import java.util.List;

import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.NominalAttribute;
import mltk.core.NumericalAttribute;
import
mltk.predictor.function.Array1D;
import mltk.predictor.function.LinearFunction;
import mltk.predictor.glm.GLM;

/**
 * Package-private helpers for building GAMs.
 */
class GAMUtils {

	/**
	 * Converts the first coefficient vector of a GLM into a GAM with one component per attribute.
	 *
	 * Coefficients are consumed in attribute order: a numerical attribute takes one coefficient and becomes a
	 * linear function; a binned or nominal attribute takes up to one coefficient per bin or value and becomes an
	 * array (entries for which no coefficient is left stay 0). Attributes of any other kind are skipped. Every
	 * coefficient and the intercept are negated on the way in.
	 * NOTE(review): the negation presumably matches a sign convention used by the GLM learners; confirm.
	 *
	 * @param glm the generalized linear model.
	 * @param attList the attributes, in the order their coefficients appear in the GLM.
	 * @return the GAM.
	 */
	static GAM getGAM(GLM glm, List<Attribute> attList) {
		double[] w = glm.coefficients(0);

		GAM gam = new GAM();
		// Position of the next unused coefficient
		int k = 0;
		for (Attribute attribute : attList) {
			int attIndex = attribute.getIndex();
			int[] term = new int[] {attIndex};
			if (attribute instanceof NumericalAttribute) {
				LinearFunction func = new LinearFunction(attIndex, -w[k++]);
				gam.add(term, func);
			} else if (attribute instanceof BinnedAttribute) {
				BinnedAttribute binnedAttribute = (BinnedAttribute) attribute;
				int size = binnedAttribute.getNumBins();
				double[] predictions = new double[size];
				for (int j = 0; j < predictions.length && k < w.length; j++) {
					predictions[j] = -w[k++];
				}
				Array1D ary = new Array1D(attIndex, predictions);
				gam.add(term, ary);
			} else if (attribute instanceof NominalAttribute) {
				NominalAttribute nominalAttribute = (NominalAttribute) attribute;
				int size = nominalAttribute.getCardinality();
				double[] predictions = new double[size];
				for (int j = 0; j < predictions.length && k < w.length; j++) {
					predictions[j] = -w[k++];
				}
				Array1D ary = new Array1D(attIndex, predictions);
				gam.add(term, ary);
			}
		}
		gam.setIntercept(-glm.intercept(0));
		return gam;
	}

}

================================================
FILE: src/main/java/mltk/predictor/gam/SPLAMLearner.java
================================================
package mltk.predictor.gam;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.cmdline.options.LearnerWithTaskOptions;
import mltk.core.Instances;
import mltk.core.io.InstancesReader;
import mltk.predictor.Learner;
import mltk.predictor.Regressor;
import mltk.predictor.function.CubicSpline;
import mltk.predictor.function.LinearFunction;
import mltk.predictor.glm.GLM;
import mltk.predictor.glm.RidgeLearner;
import
mltk.predictor.io.PredictorWriter; import mltk.util.ArrayUtils; import mltk.util.MathUtils; import mltk.util.OptimUtils; import mltk.util.StatUtils; import mltk.util.VectorUtils; /** * Class for learning SPLAM models. Currently only cubic spline basis is supported. * * @author Yin Lou * */ public class SPLAMLearner extends Learner { static class Options extends LearnerWithTaskOptions { @Argument(name = "-d", description = "number of knots (default: 10)") int numKnots = 10; @Argument(name = "-m", description = "maximum number of iterations (default: 0)") int maxNumIters = 0; @Argument(name = "-l", description = "lambda (default: 0)") double lambda = 0; @Argument(name = "-a", description = "alpha (default: 1, i.e., SPAM model)") double alpha = 1; @Argument(name = "-L", description = "whether to compute lambda max for a given a") boolean lambdaMax = false; } /** * Trains a SPLAM. * *
	 * Usage: mltk.predictor.gam.SPLAMLearner
	 * -t	train set path
	 * [-g]	task between classification (c) and regression (r) (default: r)
	 * [-r]	attribute file path
	 * [-o]	output model path
	 * [-V]	verbose (default: true)
	 * [-d]	number of knots (default: 10)
	 * [-m]	maximum number of iterations (default: 0)
	 * [-l]	lambda (default: 0)
	 * [-a]	alpha (default: 1, i.e., SPAM model)
	 * [-L]	whether to compute lambda max for a given a
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(SPLAMLearner.class, opts); Task task = null; try { parser.parse(args); task = Task.get(opts.task); if (opts.numKnots < 0) { throw new IllegalArgumentException("Number of knots must be positive."); } } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath); SPLAMLearner learner = new SPLAMLearner(); learner.setNumKnots(opts.numKnots); learner.setMaxNumIters(opts.maxNumIters); learner.setLambda(opts.lambda); learner.setAlpha(opts.alpha); learner.setTask(task); learner.setVerbose(opts.verbose); if (opts.lambdaMax) { System.out.println(learner.findMaxLambda(trainSet, task, opts.numKnots, opts.alpha)); System.exit(0); } long start = System.currentTimeMillis(); GAM gam = learner.build(trainSet); long end = System.currentTimeMillis(); System.out.println("Time: " + (end - start) / 1000.0); if (opts.outputModelPath != null) { PredictorWriter.write(gam, opts.outputModelPath); } } static class ModelStructure { static final byte ELIMINATED = 0; static final byte LINEAR = 1; static final byte NONLINEAR = 2; byte[] structure; ModelStructure(byte[] structure) { this.structure = structure; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; ModelStructure other = (ModelStructure) obj; if (!Arrays.equals(structure, other.structure)) return false; return true; } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + Arrays.hashCode(structure); return result; } } private boolean fitIntercept; private boolean refit; private int numKnots; private int maxNumIters; private Task task; private double lambda; private double alpha; private double epsilon; /** * 
Constructor. */ public SPLAMLearner() { verbose = false; fitIntercept = true; refit = false; numKnots = 10; maxNumIters = -1; lambda = 0.0; alpha = 1; epsilon = MathUtils.EPSILON; task = Task.REGRESSION; } @Override public GAM build(Instances instances) { GAM gam = null; if (maxNumIters < 0) { maxNumIters = 20; } if (numKnots < 0) { numKnots = 10; } switch (task) { case REGRESSION: gam = buildRegressor(instances, maxNumIters, numKnots, lambda, alpha); break; case CLASSIFICATION: gam = buildClassifier(instances, maxNumIters, numKnots, lambda, alpha); break; default: break; } return gam; } /** * Returns a binary classifier. * * @param attrs the attribute list. * @param x the inputs. * @param y the targets. * @param knots the knots. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @param alpha the alpha. * @return a binary classifier. */ public GAM buildBinaryClassifier(int[] attrs, double[][][] x, double[] y, double[][] knots, int maxNumIters, double lambda, double alpha) { double[][] w = new double[attrs.length][]; int m = 0; for (int j = 0; j < attrs.length; j++) { w[j] = new double[x[j].length]; if (w[j].length > m) { m = w[j].length; } } double[] tl1 = new double[attrs.length]; double[] tl2 = new double[attrs.length]; getRegularizationParameters(lambda, alpha, tl1, tl2, y.length); double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); double[] stepSize = new double[attrs.length]; for (int j = 0; j < stepSize.length; j++) { double max = 0; double[][] block = x[j]; for (double[] t : block) { double l = StatUtils.sumSq(t) / 4; if (l > max) { max = l; } } stepSize[j] = 1.0 / max; } double[] g = new double[m]; double[] gradient = new double[m]; double[] gamma1 = new double[m]; double[] gamma2 = new double[m - 1]; boolean[] activeSet = new boolean[attrs.length]; double intercept = 0; // Block coordinate gradient descent int iter = 0; while (iter < 
maxNumIters) { if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } boolean activeSetChanged = doOnePass(x, y, tl1, tl2, true, activeSet, w, stepSize, g, gradient, gamma1, gamma2, pTrain, rTrain); iter++; if (!activeSetChanged || iter > maxNumIters) { break; } for (; iter < maxNumIters; iter++) { double prevLoss = OptimUtils.computeLogisticLoss(pTrain, y) + getPenalty(w, tl1, tl2); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePass(x, y, tl1, tl2, false, activeSet, w, stepSize, g, gradient, gamma1, gamma2, pTrain, rTrain); double currLoss = OptimUtils.computeLogisticLoss(pTrain, y) + getPenalty(w, tl1, tl2); if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } } } if (refit) { byte[] structure = extractStructure(w); GAM gam = refitClassifier(attrs, structure, x, y, knots, w, maxNumIters); return gam; } else { return getGAM(attrs, knots, w, intercept); } } /** * Returns a binary classifier. * * @param attrs the attribute list. * @param indices the indices. * @param values the values. * @param y the targets. * @param knots the knots. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @param alpha the alpha. * @return a binary classifier. 
*/ public GAM buildBinaryClassifier(int[] attrs, int[][] indices, double[][][] values, double[] y, double[][] knots, int maxNumIters, double lambda, double alpha) { double[][] w = new double[attrs.length][]; int m = 0; for (int j = 0; j < attrs.length; j++) { w[j] = new double[values[j].length]; if (w[j].length > m) { m = w[j].length; } } double[] tl1 = new double[attrs.length]; double[] tl2 = new double[attrs.length]; getRegularizationParameters(lambda, alpha, tl1, tl2, y.length); double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); double[] stepSize = new double[attrs.length]; for (int j = 0; j < stepSize.length; j++) { double max = 0; double[][] block = values[j]; for (double[] t : block) { double l = StatUtils.sumSq(t) / 4; if (l > max) { max = l; } } stepSize[j] = 1.0 / max; } double[] g = new double[m]; double[] gradient = new double[m]; double[] gamma1 = new double[m]; double[] gamma2 = new double[m - 1]; boolean[] activeSet = new boolean[attrs.length]; double intercept = 0; // Block coordinate gradient descent int iter = 0; while (iter < maxNumIters) { if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, y, rTrain); } boolean activeSetChanged = doOnePass(indices, values, y, tl1, tl2, true, activeSet, w, stepSize, g, gradient, gamma1, gamma2, pTrain, rTrain); iter++; if (!activeSetChanged || iter > maxNumIters) { break; } for (; iter < maxNumIters; iter++) { double prevLoss = OptimUtils.computeLogisticLoss(pTrain, y) + getPenalty(w, tl1, tl2); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePass(indices, values, y, tl1, tl2, false, activeSet, w, stepSize, g, gradient, gamma1, gamma2, pTrain, rTrain); double currLoss = OptimUtils.computeLogisticLoss(pTrain, y) + getPenalty(w, tl1, tl2); if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } } } 
if (refit) { byte[] structure = extractStructure(w); GAM gam = refitClassifier(attrs, structure, indices, values, y, knots, w, maxNumIters * 10); return gam; } else { return getGAM(attrs, knots, w, intercept); } } /** * Builds a classifier. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param numKnots the number of knots. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @param alpha the alpha. * @return a classifier. */ public GAM buildClassifier(Instances trainSet, boolean isSparse, int numKnots, int maxNumIters, double lambda, double alpha) { if (isSparse) { SparseDataset sd = getSparseDataset(trainSet, false); SparseDesignMatrix sm = SparseDesignMatrix.createCubicSplineDesignMatrix(trainSet.size(), sd.indices, sd.values, sd.stdList, numKnots); double[] y = sd.y; int[][] indices = sm.indices; double[][][] values = sm.values; double[][] knots = sm.knots; int[] attrs = sd.attrs; // Mapping from attribute index to index in design matrix Map map = new HashMap<>(); for (int j = 0; j < sd.attrs.length; j++) { map.put(attrs[j], j); } GAM gam = buildBinaryClassifier(attrs, indices, values, y, knots, maxNumIters, lambda, alpha); // Rescale weights in gam List regressors = gam.getRegressors(); List terms = gam.getTerms(); double intercept = gam.getIntercept(); for (int i = 0; i < regressors.size(); i++) { Regressor regressor = regressors.get(i); int attIndex = terms.get(i)[0]; int idx = map.get(attIndex); double[] std = sm.std[idx]; if (regressor instanceof LinearFunction) { LinearFunction func = (LinearFunction) regressor; func.setSlope(func.getSlope() / std[0]); } else if (regressor instanceof CubicSpline) { CubicSpline spline = (CubicSpline) regressor; double[] w = spline.getCoefficients(); for (int j = 0; j < w.length; j++) { w[j] /= std[j]; } double[] k = spline.getKnots(); for (int j = 0; j < k.length; j++) { intercept -= w[j + 3] * CubicSpline.h(0, k[j]); } } } if 
(fitIntercept) { gam.setIntercept(intercept); } return gam; } else { DenseDataset dd = getDenseDataset(trainSet, false); DenseDesignMatrix dm = DenseDesignMatrix.createCubicSplineDesignMatrix(dd.x, dd.stdList, numKnots); double[] y = dd.y; double[][][] x = dm.x; double[][] knots = dm.knots; int[] attrs = dd.attrs; // Mapping from attribute index to index in design matrix Map map = new HashMap<>(); for (int j = 0; j < dd.attrs.length; j++) { map.put(dd.attrs[j], j); } GAM gam = buildBinaryClassifier(attrs, x, y, knots, maxNumIters, lambda, alpha); // Rescale weights in gam List regressors = gam.getRegressors(); List terms = gam.getTerms(); for (int i = 0; i < regressors.size(); i++) { Regressor regressor = regressors.get(i); int attIndex = terms.get(i)[0]; int idx = map.get(attIndex); double[] std = dm.std[idx]; if (regressor instanceof LinearFunction) { LinearFunction func = (LinearFunction) regressor; func.setSlope(func.getSlope() / std[0]); } else if (regressor instanceof CubicSpline) { CubicSpline spline = (CubicSpline) regressor; double[] w = spline.getCoefficients(); for (int j = 0; j < w.length; j++) { w[j] /= std[j]; } } } return gam; } } /** * Builds a classifier. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @param numKnots the number of knots. * @param lambda the lambda. * @param alpha the alpha. * @return a classifier. */ public GAM buildClassifier(Instances trainSet, int maxNumIters, int numKnots, double lambda, double alpha) { return buildClassifier(trainSet, isSparse(trainSet), maxNumIters, numKnots, lambda, alpha); } /** * Builds a regressor. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param maxNumIters the maximum number of iterations. * @param numKnots the number of knots. * @param lambda the lambda. * @param alpha the alpha. * @return a regressor. 
*/ public GAM buildRegressor(Instances trainSet, boolean isSparse, int maxNumIters, int numKnots, double lambda, double alpha) { if (isSparse) { SparseDataset sd = getSparseDataset(trainSet, false); SparseDesignMatrix sm = SparseDesignMatrix.createCubicSplineDesignMatrix(trainSet.size(), sd.indices, sd.values, sd.stdList, numKnots); double[] y = sd.y; int[][] indices = sm.indices; double[][][] values = sm.values; double[][] knots = sm.knots; int[] attrs = sd.attrs; // Mapping from attribute index to index in design matrix Map map = new HashMap<>(); for (int j = 0; j < sd.attrs.length; j++) { map.put(attrs[j], j); } GAM gam = buildRegressor(attrs, indices, values, y, knots, maxNumIters, lambda, alpha); // Rescale weights in gam List regressors = gam.getRegressors(); List terms = gam.getTerms(); double intercept = gam.getIntercept(); for (int i = 0; i < regressors.size(); i++) { Regressor regressor = regressors.get(i); int attIndex = terms.get(i)[0]; int idx = map.get(attIndex); double[] std = sm.std[idx]; if (regressor instanceof LinearFunction) { LinearFunction func = (LinearFunction) regressor; func.setSlope(func.getSlope() / std[0]); } else if (regressor instanceof CubicSpline) { CubicSpline spline = (CubicSpline) regressor; double[] w = spline.getCoefficients(); for (int j = 0; j < w.length; j++) { w[j] /= std[j]; } double[] k = spline.getKnots(); for (int j = 0; j < k.length; j++) { intercept -= w[j + 3] * CubicSpline.h(0, k[j]); } } } if (fitIntercept) { gam.setIntercept(intercept); } return gam; } else { DenseDataset dd = getDenseDataset(trainSet, false); DenseDesignMatrix dm = DenseDesignMatrix.createCubicSplineDesignMatrix(dd.x, dd.stdList, numKnots); double[] y = dd.y; double[][][] x = dm.x; double[][] knots = dm.knots; int[] attrs = dd.attrs; // Mapping from attribute index to index in design matrix Map map = new HashMap<>(); for (int j = 0; j < dd.attrs.length; j++) { map.put(dd.attrs[j], j); } GAM gam = buildRegressor(attrs, x, y, knots, maxNumIters, 
lambda, alpha); // Rescale weights in gam List regressors = gam.getRegressors(); List terms = gam.getTerms(); for (int i = 0; i < regressors.size(); i++) { Regressor regressor = regressors.get(i); int attIndex = terms.get(i)[0]; int idx = map.get(attIndex); double[] std = dm.std[idx]; if (regressor instanceof LinearFunction) { LinearFunction func = (LinearFunction) regressor; func.setSlope(func.getSlope() / std[0]); } else if (regressor instanceof CubicSpline) { CubicSpline spline = (CubicSpline) regressor; double[] w = spline.getCoefficients(); for (int j = 0; j < w.length; j++) { w[j] /= std[j]; } } } return gam; } } /** * Builds a regressor. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @param numKnots the number of knots. * @param lambda the lambda. * @param alpha the alpha. * @return a regressor. */ public GAM buildRegressor(Instances trainSet, int maxNumIters, int numKnots, double lambda, double alpha) { return buildRegressor(trainSet, isSparse(trainSet), maxNumIters, numKnots, lambda, alpha); } /** * Returns a regressor. * * @param attrs the attribute list. * @param x the inputs. * @param y the targets. * @param knots the knots. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @param alpha the alpha. * @return a regressor. 
*/ public GAM buildRegressor(int[] attrs, double[][][] x, double[] y, double[][] knots, int maxNumIters, double lambda, double alpha) { // Backup targets double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } double[][] w = new double[attrs.length][]; int m = 0; for (int j = 0; j < attrs.length; j++) { w[j] = new double[x[j].length]; if (w[j].length > m) { m = w[j].length; } } double[] tl1 = new double[attrs.length]; double[] tl2 = new double[attrs.length]; getRegularizationParameters(lambda, alpha, tl1, tl2, y.length); double[] stepSize = new double[attrs.length]; for (int j = 0; j < stepSize.length; j++) { double max = 0; double[][] block = x[j]; for (double[] t : block) { double l = StatUtils.sumSq(t); if (l > max) { max = l; } } stepSize[j] = 1.0 / max; } double[] g = new double[m]; double[] gradient = new double[m]; double[] gamma1 = new double[m]; double[] gamma2 = new double[m - 1]; boolean[] activeSet = new boolean[attrs.length]; double intercept = 0; // Block coordinate gradient descent int iter = 0; while (iter < maxNumIters) { if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } boolean activeSetChanged = doOnePass(x, tl1, tl2, true, activeSet, w, stepSize, g, gradient, gamma1, gamma2, rTrain); iter++; if (!activeSetChanged || iter > maxNumIters) { break; } for (; iter < maxNumIters; iter++) { double prevLoss = OptimUtils.computeQuadraticLoss(rTrain) + getPenalty(w, tl1, tl2); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePass(x, tl1, tl2, false, activeSet, w, stepSize, g, gradient, gamma1, gamma2, rTrain); double currLoss = OptimUtils.computeQuadraticLoss(rTrain) + getPenalty(w, tl1, tl2); if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } } } if (refit) { byte[] struct = extractStructure(w); return refitRegressor(attrs, struct, x, y, knots, w, maxNumIters); } else { return 
getGAM(attrs, knots, w, intercept); } } /** * Returns a regressor. * * @param attrs the attribute list. * @param indices the indices. * @param values the values. * @param y the targets. * @param knots the knots. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @param alpha the alpha. * @return a regressor. */ public GAM buildRegressor(int[] attrs, int[][] indices, double[][][] values, double[] y, double[][] knots, int maxNumIters, double lambda, double alpha) { // Backup targets double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } double[][] w = new double[attrs.length][]; int m = 0; for (int j = 0; j < attrs.length; j++) { w[j] = new double[values[j].length]; if (w[j].length > m) { m = w[j].length; } } double[] tl1 = new double[attrs.length]; double[] tl2 = new double[attrs.length]; getRegularizationParameters(lambda, alpha, tl1, tl2, y.length); double[] stepSize = new double[attrs.length]; for (int j = 0; j < stepSize.length; j++) { double max = 0; int[] index = indices[j]; double[][] block = values[j]; for (double[] t : block) { double l = 0; for (int i = 0; i < index.length; i++) { l += y[index[j]] * t[i]; } if (l > max) { max = l; } } stepSize[j] = 1.0 / max; } double[] g = new double[m]; double[] gradient = new double[m]; double[] gamma1 = new double[m]; double[] gamma2 = new double[m - 1]; boolean[] activeSet = new boolean[attrs.length]; double intercept = 0; // Block coordinate gradient descent int iter = 0; while (iter < maxNumIters) { if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } boolean activeSetChanged = doOnePass(indices, values, tl1, tl2, true, activeSet, w, stepSize, g, gradient, gamma1, gamma2, rTrain); iter++; if (!activeSetChanged || iter > maxNumIters) { break; } for (; iter < maxNumIters; iter++) { double prevLoss = OptimUtils.computeQuadraticLoss(rTrain) + getPenalty(w, tl1, tl2); if (fitIntercept) { intercept += 
OptimUtils.fitIntercept(rTrain); } doOnePass(indices, values, tl1, tl2, false, activeSet, w, stepSize, g, gradient, gamma1, gamma2, rTrain); double currLoss = OptimUtils.computeQuadraticLoss(rTrain) + getPenalty(w, tl1, tl2); if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } } } if (refit) { byte[] structure = extractStructure(w); GAM gam = refitRegressor(attrs, structure, indices, values, y, knots, w, maxNumIters * 10); return gam; } else { return getGAM(attrs, knots, w, intercept); } } /** * Returns {@code true} if we fit intercept. * * @return {@code true} if we fit intercept. */ public boolean fitIntercept() { return fitIntercept; } /** * Sets whether we fit intercept. * * @param fitIntercept whether we fit intercept. */ public void fitIntercept(boolean fitIntercept) { this.fitIntercept = fitIntercept; } /** * Returns the alpha. * * @return the alpha; */ public double getAlpha() { return alpha; } /** * Returns the convergence threshold epsilon. * * @return the convergence threshold epsilon. */ public double getEpsilon() { return epsilon; } /** * Returns the lambda. * * @return the lambda. */ public double getLambda() { return lambda; } /** * Returns the maximum number of iterations. * * @return the maximum number of iterations. */ public int getMaxNumIters() { return maxNumIters; } /** * Returns the number of knots. * * @return the number of knots. */ public int getNumKnots() { return numKnots; } /** * Returns the task of this learner. * * @return the task of this learner. */ public Task getTask() { return task; } /** * Returns {@code true} if we output something during the training. * * @return {@code true} if we output something during the training. */ public boolean isVerbose() { return verbose; } /** * Returns {@code true} if we refit the model. * * @return {@code true} if we refit the model. 
*/
    public boolean refit() {
        return this.refit;
    }

    /**
     * Turns refitting of the model on or off.
     *
     * @param refit {@code true} to refit the model.
     */
    public void refit(boolean refit) {
        this.refit = refit;
    }

    /**
     * Replaces the alpha used by this learner.
     *
     * @param alpha the new alpha.
     */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Replaces the convergence threshold epsilon.
     *
     * @param epsilon the new convergence threshold.
     */
    public void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    /**
     * Replaces the lambda used by this learner.
     *
     * @param lambda the new lambda.
     */
    public void setLambda(double lambda) {
        this.lambda = lambda;
    }

    /**
     * Replaces the maximum number of iterations.
     *
     * @param maxNumIters the new maximum number of iterations.
     */
    public void setMaxNumIters(int maxNumIters) {
        this.maxNumIters = maxNumIters;
    }

    /**
     * Replaces the number of knots.
     *
     * @param numKnots the new number of knots.
     */
    public void setNumKnots(int numKnots) {
        this.numKnots = numKnots;
    }

    /**
     * Replaces the task of this learner.
     *
     * @param task the new task.
     */
    public void setTask(Task task) {
        this.task = task;
    }

    /**
     * Sets whether we output something during the training.
     *
     * @param verbose the switch if we output things during training.
*/ public void setVerbose(boolean verbose) { this.verbose = verbose; } public double findMaxLambda(Instances trainSet, Task task, int numKnots, double alpha) { DenseDataset dd = getDenseDataset(trainSet, false); DenseDesignMatrix dm = DenseDesignMatrix.createCubicSplineDesignMatrix(dd.x, dd.stdList, numKnots); double[] y = dd.y; double[][][] x = dm.x; int[] attrs = dd.attrs; double[][] w = new double[attrs.length][]; int m = 0; for (int j = 0; j < attrs.length; j++) { w[j] = new double[x[j].length]; if (w[j].length > m) { m = w[j].length; } } double[] tl1 = new double[attrs.length]; double[] tl2 = new double[attrs.length]; double[] g = new double[m]; double[] gradient = new double[m]; double[] gamma1 = new double[m]; double[] gamma2 = new double[m - 1]; if (task == Task.REGRESSION) { return findMaxLambda(x, y, alpha, tl1, tl2, w, g, gradient, gamma1, gamma2); } else { double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); return findMaxLambda(x, y, pTrain, rTrain, alpha, tl1, tl2, w, g, gradient, gamma1, gamma2); } } protected double findMaxLambda(double[][][] x, double[] rTrain, double alpha, double[] tl1, double[] tl2, double[][] w, double[] g, double[] gradient, double[] gamma1, double[] gamma2) { double mean = 0; if (fitIntercept) { mean = OptimUtils.fitIntercept(rTrain); } double lHigh = 0; for (double[][] block : x) { computeGradient(block, rTrain, gradient); double t = Math.sqrt(StatUtils.sumSq(gradient, 0, block.length)) / Math.sqrt(block.length); if (t > lHigh) { lHigh = t; } } lHigh /= alpha; double lLow = 0; while (lHigh - lLow > MathUtils.EPSILON) { double lambda = (lHigh + lLow) / 2; for (int j = 0; j < x.length; j++) { tl1[j] = lambda * alpha * Math.sqrt(w[j].length); tl2[j] = lambda * (1 - alpha) * Math.sqrt(w[j].length - 1); } boolean isZeroPoint = testZeroPoint(x, rTrain, tl1, tl2, w, g, gradient, gamma1, gamma2); if (isZeroPoint) { lHigh = lambda; } else { lLow = lambda; 
} } if (fitIntercept) { VectorUtils.add(rTrain, mean); } return lHigh; } protected double findMaxLambda(double[][][] x, double[] y, double[] pTrain, double[] rTrain, double alpha, double[] tl1, double[] tl2, double[][] coefficients, double[] g, double[] gradient, double[] gamma1, double[] gamma2) { if (fitIntercept) { OptimUtils.fitIntercept(pTrain, rTrain, y); } double lHigh = 0; for (double[][] block : x) { computeGradient(block, rTrain, gradient); double t = Math.sqrt(StatUtils.sumSq(gradient, 0, block.length)) / Math.sqrt(block.length); if (t > lHigh) { lHigh = t; } } lHigh /= alpha; double lLow = 0; while (lHigh - lLow > MathUtils.EPSILON) { double lambda = (lHigh + lLow) / 2; for (int j = 0; j < x.length; j++) { tl1[j] = lambda * alpha * Math.sqrt(coefficients[j].length); tl2[j] = lambda * (1 - alpha) * Math.sqrt(coefficients[j].length - 1); } boolean isZeroPoint = testZeroPoint(x, y, pTrain, rTrain, tl1, tl2, coefficients, g, gradient, gamma1, gamma2); if (isZeroPoint) { lHigh = lambda; } else { lLow = lambda; } } if (fitIntercept) { Arrays.fill(pTrain, 0); } return lHigh; } protected boolean testZeroPoint(double[][][] x, double[] y, double[] tl1, double[] tl2, double[][] w, double[] g, double[] gradient, double[] gamma1, double[] gamma2) { for (int k = 0; k < x.length; k++) { double[][] block = x[k]; final double lambda1 = tl1[k]; final double lambda2 = tl2[k]; // Proximal gradient method computeGradient(block, y, gradient); double[] beta = w[k]; for (int i = 0; i < beta.length; i++) { g[i] = gradient[i]; } // Dual method if (beta.length > 1) { for (int i = 1; i < beta.length; i++) { gamma2[i - 1] = g[i]; } double norm2 = VectorUtils.l2norm(gamma2); double t2 = lambda2; if (norm2 > t2) { VectorUtils.multiply(gamma2, t2 / norm2); } } gamma1[0] = g[0]; for (int i = 1; i < beta.length; i++) { gamma1[i] = g[i] - gamma2[i - 1]; } double norm1 = Math.sqrt(StatUtils.sumSq(gamma1, 0, beta.length)); double t1 = lambda1; if (norm1 > t1) { VectorUtils.multiply(gamma1, 
t1 / norm1); } g[0] -= gamma1[0]; for (int i = 1; i < beta.length; i++) { g[i] -= (gamma1[i] + gamma2[i - 1]); } if (!ArrayUtils.isConstant(g, 0, beta.length, 0)) { return false; } } return true; } protected boolean testZeroPoint(double[][][] x, double[] y, double[] pTrain, double[] rTrain, double[] tl1, double[] tl2, double[][] coefficients, double[] g, double[] gradient, double[] gamma1, double[] gamma2) { for (int k = 0; k < x.length; k++) { double[][] block = x[k]; final double lambda1 = tl1[k]; final double lambda2 = tl2[k]; // Proximal gradient method computeGradient(block, rTrain, gradient); double[] beta = coefficients[k]; for (int i = 0; i < beta.length; i++) { g[i] = gradient[i]; } // Dual method if (beta.length > 1) { for (int i = 1; i < beta.length; i++) { gamma2[i - 1] = g[i]; } double norm2 = VectorUtils.l2norm(gamma2); double t2 = lambda2; if (norm2 > t2) { VectorUtils.multiply(gamma2, t2 / norm2); } } gamma1[0] = g[0]; for (int i = 1; i < beta.length; i++) { gamma1[i] = g[i] - gamma2[i - 1]; } double norm1 = Math.sqrt(StatUtils.sumSq(gamma1, 0, beta.length)); double t1 = lambda1; if (norm1 > t1) { VectorUtils.multiply(gamma1, t1 / norm1); } g[0] -= gamma1[0]; for (int i = 1; i < beta.length; i++) { g[i] -= (gamma1[i] + gamma2[i - 1]); } if (!ArrayUtils.isConstant(g, 0, beta.length, 0)) { return false; } } return true; } protected void computeGradient(double[][] block, double[] rTrain, double[] gradient) { for (int i = 0; i < block.length; i++) { gradient[i] = VectorUtils.dotProduct(block[i], rTrain); } } protected void computeGradient(int[] index, double[][] block, double[] rTrain, double[] gradient) { for (int j = 0; j < block.length; j++) { double[] t = block[j]; gradient[j] = 0; for (int i = 0; i < t.length; i++) { gradient[j] += rTrain[index[i]] * t[i]; } } } protected boolean doOnePass(double[][][] x, double[] tl1, double[] tl2, boolean isFullPass, boolean[] activeSet, double[][] w, double[] stepSize, double[] g, double[] gradient, double[] 
gamma1, double[] gamma2, double[] rTrain) { boolean activeSetChanged = false; for (int k = 0; k < x.length; k++) { if (!isFullPass && !activeSet[k]) { continue; } double[][] block = x[k]; double[] beta = w[k]; double tk = stepSize[k]; double lambda1 = tl1[k]; double lambda2 = tl2[k]; // Proximal gradient method computeGradient(block, rTrain, gradient); for (int j = 0; j < beta.length; j++) { g[j] = beta[j] + tk * gradient[j]; } // Dual method if (beta.length > 1) { for (int i = 1; i < beta.length; i++) { gamma2[i - 1] = g[i]; } double norm2 = Math.sqrt(StatUtils.sumSq(gamma2, 0, beta.length - 1)); double t2 = lambda2 * tk; if (norm2 > t2) { VectorUtils.multiply(gamma2, t2 / norm2); } } gamma1[0] = g[0]; for (int i = 1; i < beta.length; i++) { gamma1[i] = g[i] - gamma2[i - 1]; } double norm1 = Math.sqrt(StatUtils.sumSq(gamma1, 0, beta.length)); double t1 = lambda1 * tk; if (norm1 > t1) { VectorUtils.multiply(gamma1, t1 / norm1); } g[0] -= gamma1[0]; for (int i = 1; i < beta.length; i++) { g[i] -= (gamma1[i] + gamma2[i - 1]); } // Update residuals for (int j = 0; j < beta.length; j++) { double[] t = block[j]; double delta = beta[j] - g[j]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] += delta * t[i]; } } // Update weights for (int j = 0; j < beta.length; j++) { beta[j] = g[j]; } if (isFullPass && !activeSet[k] && !ArrayUtils.isConstant(beta, 0, beta.length, 0)) { activeSetChanged = true; activeSet[k] = true; } } return activeSetChanged; } protected boolean doOnePass(double[][][] x, double[] y, double[] tl1, double[] tl2, boolean isFullPass, boolean[] activeSet, double[][] w, double[] stepSize, double[] g, double[] gradient, double[] gamma1, double[] gamma2, double[] pTrain, double[] rTrain) { boolean activeSetChanged = false; for (int k = 0; k < x.length; k++) { if (!isFullPass && !activeSet[k]) { continue; } double[][] block = x[k]; double[] beta = w[k]; double tk = stepSize[k]; double lambda1 = tl1[k]; double lambda2 = tl2[k]; // Proximal gradient method 
computeGradient(block, rTrain, gradient); for (int j = 0; j < beta.length; j++) { g[j] = beta[j] + tk * gradient[j]; } // Dual method if (beta.length > 1) { for (int i = 1; i < beta.length; i++) { gamma2[i - 1] = g[i]; } double norm2 = Math.sqrt(StatUtils.sumSq(gamma2, 0, beta.length - 1)); double t2 = lambda2 * tk; if (norm2 > t2) { VectorUtils.multiply(gamma2, t2 / norm2); } } gamma1[0] = g[0]; for (int i = 1; i < beta.length; i++) { gamma1[i] = g[i] - gamma2[i - 1]; } double norm1 = Math.sqrt(StatUtils.sumSq(gamma1, 0, beta.length)); double t1 = lambda1 * tk; if (norm1 > t1) { VectorUtils.multiply(gamma1, t1 / norm1); } g[0] -= gamma1[0]; for (int i = 1; i < beta.length; i++) { g[i] -= (gamma1[i] + gamma2[i - 1]); } // Update predictions for (int j = 0; j < beta.length; j++) { double[] t = block[j]; double delta = g[j] - beta[j]; for (int i = 0; i < y.length; i++) { pTrain[i] += delta * t[i]; } } OptimUtils.computePseudoResidual(pTrain, y, rTrain); // Update weights for (int j = 0; j < beta.length; j++) { beta[j] = g[j]; } if (isFullPass && !activeSet[k] && !ArrayUtils.isConstant(beta, 0, beta.length, 0)) { activeSetChanged = true; activeSet[k] = true; } } return activeSetChanged; } protected boolean doOnePass(int[][] indices, double[][][] values, double[] tl1, double[] tl2, boolean isFullPass, boolean[] activeSet, double[][] w, double[] stepSize, double[] g, double[] gradient, double[] gamma1, double[] gamma2, double[] rTrain) { boolean activeSetChanged = false; for (int k = 0; k < values.length; k++) { if (!isFullPass && !activeSet[k]) { continue; } double[][] block = values[k]; int[] index = indices[k]; double[] beta = w[k]; double tk = stepSize[k]; double lambda1 = tl1[k]; double lambda2 = tl2[k]; // Proximal gradient method computeGradient(index, block, rTrain, gradient); for (int j = 0; j < beta.length; j++) { g[j] = beta[j] + tk * gradient[j]; } // Dual method if (beta.length > 1) { for (int i = 1; i < beta.length; i++) { gamma2[i - 1] = g[i]; } double 
norm2 = Math.sqrt(StatUtils.sumSq(gamma2, 0, beta.length - 1));
				double t2 = lambda2 * tk;
				if (norm2 > t2) {
					VectorUtils.multiply(gamma2, t2 / norm2);
				}
			}
			gamma1[0] = g[0];
			for (int i = 1; i < beta.length; i++) {
				gamma1[i] = g[i] - gamma2[i - 1];
			}
			double norm1 = Math.sqrt(StatUtils.sumSq(gamma1, 0, beta.length));
			double t1 = lambda1 * tk;
			if (norm1 > t1) {
				VectorUtils.multiply(gamma1, t1 / norm1);
			}
			g[0] -= gamma1[0];
			for (int i = 1; i < beta.length; i++) {
				g[i] -= (gamma1[i] + gamma2[i - 1]);
			}

			// Update predictions
			for (int j = 0; j < beta.length; j++) {
				double[] t = block[j];
				double delta = beta[j] - g[j];
				for (int i = 0; i < t.length; i++) {
					rTrain[index[i]] += delta * t[i];
				}
			}

			// Update weights
			for (int j = 0; j < beta.length; j++) {
				beta[j] = g[j];
			}

			if (isFullPass && !activeSet[k] && !ArrayUtils.isConstant(beta, 0, beta.length, 0)) {
				activeSetChanged = true;
				activeSet[k] = true;
			}
		}

		return activeSetChanged;
	}

	/**
	 * Performs one block-wise pass over sparse blocks, maintaining both the raw predictions and the
	 * pseudo residuals derived from them.
	 *
	 * <p>
	 * For every visited block the candidate weights are {@code beta + stepSize * gradient}. The candidate is
	 * then reduced by two vectors: {@code gamma2} (built from all components but the first) and {@code gamma1}
	 * (built from all components), each rescaled to norm {@code lambda * stepSize} when its norm exceeds that
	 * value. Afterwards {@code pTrain} is updated with the weight change and {@code rTrain} is recomputed for
	 * the rows listed in the block's index.
	 * </p>
	 *
	 * @param indices the row indices of the non-zero entries, one array per block.
	 * @param values the feature values, indexed by block, then column within the block, then entry.
	 * @param y the targets.
	 * @param tl1 the per-block threshold factor applied to the whole block.
	 * @param tl2 the per-block threshold factor applied to all components but the first.
	 * @param isFullPass {@code true} to visit every block; {@code false} to visit active blocks only.
	 * @param activeSet the active-set flags; blocks that become non-zero in a full pass are switched on.
	 * @param w the weights, one array per block; updated in place.
	 * @param stepSize the per-block step sizes.
	 * @param g scratch buffer for the candidate weights.
	 * @param gradient scratch buffer filled by {@code computeGradient}.
	 * @param gamma1 scratch buffer.
	 * @param gamma2 scratch buffer.
	 * @param pTrain the current predictions; updated in place.
	 * @param rTrain the current pseudo residuals; updated in place.
	 * @return {@code true} if a previously inactive block became active.
	 */
	protected boolean doOnePass(int[][] indices, double[][][] values, double[] y, double[] tl1, double[] tl2,
			boolean isFullPass, boolean[] activeSet, double[][] w, double[] stepSize, double[] g, double[] gradient,
			double[] gamma1, double[] gamma2, double[] pTrain, double[] rTrain) {
		boolean activeSetChanged = false;

		for (int k = 0; k < values.length; k++) {
			if (!isFullPass && !activeSet[k]) {
				continue;
			}

			int[] index = indices[k];
			double[][] block = values[k];
			double[] beta = w[k];
			double tk = stepSize[k];
			double lambda1 = tl1[k];
			double lambda2 = tl2[k];

			// Proximal gradient method
			computeGradient(index, block, rTrain, gradient);
			for (int j = 0; j < beta.length; j++) {
				g[j] = beta[j] + tk * gradient[j];
			}

			// Dual method
			if (beta.length > 1) {
				for (int j = 1; j < beta.length; j++) {
					gamma2[j - 1] = g[j];
				}
				double norm2 = Math.sqrt(StatUtils.sumSq(gamma2, 0, beta.length - 1));
				double t2 = lambda2 * tk;
				if (norm2 > t2) {
					VectorUtils.multiply(gamma2, t2 / norm2);
				}
			}
			gamma1[0] = g[0];
			for (int j = 1; j < beta.length; j++) {
				gamma1[j] = g[j] - gamma2[j - 1];
			}
			double norm1 = Math.sqrt(StatUtils.sumSq(gamma1, 0, beta.length));
			double t1 = lambda1 * tk;
			if (norm1 > t1) {
				VectorUtils.multiply(gamma1, t1 / norm1);
			}
			g[0] -= gamma1[0];
			for (int i = 1; i < beta.length; i++) {
				g[i] -= (gamma1[i] + gamma2[i - 1]);
			}

			// Update predictions
			for (int j = 0; j < beta.length; j++) {
				double[] value = block[j];
				double delta = g[j] - beta[j];
				for (int i = 0; i < value.length; i++) {
					pTrain[index[i]] += delta * value[i];
				}
			}
			// Only rows touched by this block can have changed, so only those residuals are refreshed
			for (int idx : index) {
				rTrain[idx] = OptimUtils.getPseudoResidual(pTrain[idx], y[idx]);
			}

			// Update weights
			for (int j = 0; j < beta.length; j++) {
				beta[j] = g[j];
			}

			if (isFullPass && !activeSet[k] && !ArrayUtils.isConstant(beta, 0, beta.length, 0)) {
				activeSetChanged = true;
				activeSet[k] = true;
			}
		}

		return activeSetChanged;
	}

	/**
	 * Classifies every weight block as eliminated, linear or nonlinear.
	 *
	 * <p>
	 * A block is linear when it has a single weight or when all weights after the first are zero; a linear
	 * block whose first weight is zero is eliminated.
	 * </p>
	 *
	 * @param w the weights, one array per block.
	 * @return one {@code ModelStructure} code per block.
	 */
	protected byte[] extractStructure(double[][] w) {
		byte[] structure = new byte[w.length];
		for (int i = 0; i < structure.length; i++) {
			double[] beta = w[i];
			boolean isLinear = beta.length == 1 || ArrayUtils.isConstant(beta, 1, beta.length, 0);
			if (!isLinear) {
				structure[i] = ModelStructure.NONLINEAR;
			} else if (beta[0] != 0) {
				structure[i] = ModelStructure.LINEAR;
			} else {
				structure[i] = ModelStructure.ELIMINATED;
			}
		}
		return structure;
	}

	/**
	 * Assembles a GAM from block weights: a {@code LinearFunction} for linear blocks with non-zero slope, a
	 * {@code CubicSpline} for nonlinear blocks, and nothing for eliminated blocks.
	 *
	 * @param attrs the attribute index of each block.
	 * @param knots the knots of each block.
	 * @param w the weights of each block.
	 * @param intercept the intercept of the model.
	 * @return the assembled GAM.
	 */
	protected GAM getGAM(int[] attrs, double[][] knots, double[][] w, double intercept) {
		GAM gam = new GAM();
		for (int j = 0; j < attrs.length; j++) {
			int attIndex = attrs[j];
			double[] beta = w[j];
			boolean isLinear = beta.length == 1 || ArrayUtils.isConstant(beta, 1, beta.length, 0);
			if (isLinear) {
				if (beta[0] != 0) {
					// To rule out a feature, it has to be "linear" and 0 slope.
					gam.add(new int[] { attIndex }, new LinearFunction(attIndex, beta[0]));
				}
			} else {
				// The spline gets its own copy so later edits to it do not touch w
				double[] coef = Arrays.copyOf(beta, beta.length);
				CubicSpline spline = new CubicSpline(attIndex, 0, knots[j], coef);
				gam.add(new int[] { attIndex }, spline);
			}
		}
		gam.setIntercept(intercept);
		return gam;
	}

	/**
	 * Returns the penalty of one block: {@code lambda1} times the L2 norm of all weights plus {@code lambda2}
	 * times the L2 norm of all weights but the first.
	 *
	 * @param w the weights of the block.
	 * @param lambda1 the multiplier of the whole-block norm.
	 * @param lambda2 the multiplier of the norm that excludes the first weight.
	 * @return the penalty.
	 */
	protected double getPenalty(double[] w, double lambda1, double lambda2) {
		double sumSq = StatUtils.sumSq(w);
		double norm1 = Math.sqrt(sumSq);
		double norm2 = Math.sqrt(sumSq - w[0] * w[0]);
		return lambda1 * norm1 + lambda2 * norm2;
	}

	/**
	 * Returns the sum of the per-block penalties.
	 *
	 * @param coef the weights, one array per block.
	 * @param lambda1 the whole-block multipliers, one per block.
	 * @param lambda2 the multipliers excluding the first weight, one per block.
	 * @return the total penalty.
	 */
	protected double getPenalty(double[][] coef, double[] lambda1, double[] lambda2) {
		double penalty = 0;
		for (int i = 0; i < coef.length; i++) {
			penalty += getPenalty(coef[i], lambda1[i], lambda2[i]);
		}
		return penalty;
	}

	/**
	 * Fills the per-block regularization parameters with {@code lambda * alpha * n} and
	 * {@code lambda * (1 - alpha) * n}.
	 *
	 * @param lambda the overall regularization strength.
	 * @param alpha the mixing weight between the two penalties.
	 * @param tl1 the output array for the first parameter; its length determines how many entries are set.
	 * @param tl2 the output array for the second parameter.
	 * @param n the scaling factor.
	 */
	protected void getRegularizationParameters(double lambda, double alpha, double[] tl1, double[] tl2, int n) {
		final double l1 = lambda * alpha * n;
		final double l2 = lambda * (1 - alpha) * n;
		for (int j = 0; j < tl1.length; j++) {
			tl1[j] = l1;
			tl2[j] = l2;
		}
	}

	/**
	 * Refits a binary classifier on the dense columns that survived structure selection.
	 *
	 * @param attrs the attribute index of each block.
	 * @param struct the structure code of each block.
	 * @param x the dense design matrix, indexed by block, then column, then row.
	 * @param y the targets.
	 * @param knots the knots of each block.
	 * @param w the weights that determine which terms the returned GAM contains.
	 * @param maxNumIters the maximum number of iterations of the refit.
	 * @return the refitted GAM.
	 */
	protected GAM refitClassifier(int[] attrs, byte[] struct, double[][][] x, double[] y, double[][] knots,
			double[][] w, int maxNumIters) {
		List<double[]> xList = new ArrayList<>();
		for (int i = 0; i < struct.length; i++) {
			if (struct[i] == ModelStructure.NONLINEAR) {
				double[][] t = x[i];
				for (int j = 0; j < t.length; j++) {
					xList.add(t[j]);
				}
			} else if (struct[i] == ModelStructure.LINEAR) {
				xList.add(x[i][0]);
			}
		}
		// NOTE(review): unlike refitRegressor there is no shortcut when no term survives - confirm that
		// RidgeLearner copes with zero columns.
		double[][] xNew = xList.toArray(new double[xList.size()][]);
		int[] attrsNew = identityAttributes(xNew.length);

		// A ridge regression with very small regularization parameter
		// This often improves stability a lot
		GLM glm = createRefitRidgeLearner().buildBinaryClassifier(attrsNew, xNew, y, maxNumIters, 1e-8);
		return buildRefittedGAM(attrs, knots, w, glm);
	}

	/**
	 * Refits a binary classifier on the sparse columns that survived structure selection.
	 *
	 * @param attrs the attribute index of each block.
	 * @param struct the structure code of each block.
	 * @param indices the row indices of the non-zero entries, one array per block.
	 * @param values the sparse design matrix, indexed by block, then column, then entry.
	 * @param y the targets.
	 * @param knots the knots of each block.
	 * @param w the weights that determine which terms the returned GAM contains.
	 * @param maxNumIters the maximum number of iterations of the refit.
	 * @return the refitted GAM.
	 */
	protected GAM refitClassifier(int[] attrs, byte[] struct, int[][] indices, double[][][] values, double[] y,
			double[][] knots, double[][] w, int maxNumIters) {
		List<int[]> iList = new ArrayList<>();
		List<double[]> vList = new ArrayList<>();
		for (int i = 0; i < struct.length; i++) {
			int[] index = indices[i];
			if (struct[i] == ModelStructure.NONLINEAR) {
				double[][] t = values[i];
				for (int j = 0; j < t.length; j++) {
					// Every column of a block shares the block's row indices
					iList.add(index);
					vList.add(t[j]);
				}
			} else if (struct[i] == ModelStructure.LINEAR) {
				iList.add(index);
				vList.add(values[i][0]);
			}
		}
		// NOTE(review): unlike refitRegressor there is no shortcut when no term survives - confirm that
		// RidgeLearner copes with zero columns.
		int[][] iNew = iList.toArray(new int[iList.size()][]);
		double[][] vNew = vList.toArray(new double[vList.size()][]);
		int[] attrsNew = identityAttributes(iNew.length);

		// A ridge regression with very small regularization parameter
		// This often improves stability a lot
		GLM glm = createRefitRidgeLearner().buildBinaryClassifier(attrsNew, iNew, vNew, y, maxNumIters, 1e-8);
		return buildRefittedGAM(attrs, knots, w, glm);
	}

	/**
	 * Refits a regressor on the dense columns that survived structure selection. When no column survived, an
	 * intercept-only model is returned without running the ridge learner.
	 *
	 * @param attrs the attribute index of each block.
	 * @param struct the structure code of each block.
	 * @param x the dense design matrix, indexed by block, then column, then row.
	 * @param y the targets.
	 * @param knots the knots of each block.
	 * @param w the weights that determine which terms the returned GAM contains.
	 * @param maxNumIters the maximum number of iterations of the refit.
	 * @return the refitted GAM.
	 */
	protected GAM refitRegressor(int[] attrs, byte[] struct, double[][][] x, double[] y, double[][] knots,
			double[][] w, int maxNumIters) {
		List<double[]> xList = new ArrayList<>();
		for (int i = 0; i < struct.length; i++) {
			if (struct[i] == ModelStructure.NONLINEAR) {
				double[][] t = x[i];
				for (int j = 0; j < t.length; j++) {
					xList.add(t[j]);
				}
			} else if (struct[i] == ModelStructure.LINEAR) {
				xList.add(x[i][0]);
			}
		}

		if (xList.isEmpty()) {
			return buildInterceptOnlyGAM(y);
		}

		double[][] xNew = xList.toArray(new double[xList.size()][]);
		int[] attrsNew = identityAttributes(xNew.length);

		// A ridge regression with very small regularization parameter
		// This often improves stability a lot
		GLM glm = createRefitRidgeLearner().buildGaussianRegressor(attrsNew, xNew, y, maxNumIters, 1e-8);
		return buildRefittedGAM(attrs, knots, w, glm);
	}

	/**
	 * Refits a regressor on the sparse columns that survived structure selection. When no column survived, an
	 * intercept-only model is returned without running the ridge learner, matching the dense overload.
	 *
	 * @param attrs the attribute index of each block.
	 * @param struct the structure code of each block.
	 * @param indices the row indices of the non-zero entries, one array per block.
	 * @param values the sparse design matrix, indexed by block, then column, then entry.
	 * @param y the targets.
	 * @param knots the knots of each block.
	 * @param w the weights that determine which terms the returned GAM contains.
	 * @param maxNumIters the maximum number of iterations of the refit.
	 * @return the refitted GAM.
	 */
	protected GAM refitRegressor(int[] attrs, byte[] struct, int[][] indices, double[][][] values, double[] y,
			double[][] knots, double[][] w, int maxNumIters) {
		List<int[]> iList = new ArrayList<>();
		List<double[]> vList = new ArrayList<>();
		for (int i = 0; i < struct.length; i++) {
			int[] index = indices[i];
			if (struct[i] == ModelStructure.NONLINEAR) {
				double[][] t = values[i];
				for (int j = 0; j < t.length; j++) {
					// Every column of a block shares the block's row indices
					iList.add(index);
					vList.add(t[j]);
				}
			} else if (struct[i] == ModelStructure.LINEAR) {
				iList.add(index);
				vList.add(values[i][0]);
			}
		}

		if (iList.isEmpty()) {
			return buildInterceptOnlyGAM(y);
		}

		int[][] iNew = iList.toArray(new int[iList.size()][]);
		double[][] vNew = vList.toArray(new double[vList.size()][]);
		int[] attrsNew = identityAttributes(iNew.length);

		// A ridge regression with very small regularization parameter
		// This often improves stability a lot
		GLM glm = createRefitRidgeLearner().buildGaussianRegressor(attrsNew, iNew, vNew, y, maxNumIters, 1e-8);
		return buildRefittedGAM(attrs, knots, w, glm);
	}

	/** Creates the ridge learner used by the refit methods, configured from this learner's settings. */
	private RidgeLearner createRefitRidgeLearner() {
		RidgeLearner ridgeLearner = new RidgeLearner();
		ridgeLearner.setVerbose(verbose);
		ridgeLearner.setEpsilon(epsilon);
		ridgeLearner.fitIntercept(fitIntercept);
		return ridgeLearner;
	}

	/** Returns the array {@code 0, 1, ..., n - 1}. */
	private static int[] identityAttributes(int n) {
		int[] attrsNew = new int[n];
		for (int i = 0; i < attrsNew.length; i++) {
			attrsNew[i] = i;
		}
		return attrsNew;
	}

	/** Returns a GAM without terms whose intercept is the mean target when an intercept is fitted. */
	private GAM buildInterceptOnlyGAM(double[] y) {
		GAM gam = new GAM();
		if (fitIntercept) {
			gam.setIntercept(StatUtils.mean(y));
		}
		return gam;
	}

	/** Builds the GAM described by {@code w} and overwrites its coefficients, in term order, with the GLM's. */
	private GAM buildRefittedGAM(int[] attrs, double[][] knots, double[][] w, GLM glm) {
		GAM gam = getGAM(attrs, knots, w, glm.intercept(0));
		double[] coef = glm.coefficients(0);
		int k = 0;
		for (Regressor regressor : gam.regressors) {
			if (regressor instanceof LinearFunction) {
				((LinearFunction) regressor).setSlope(coef[k++]);
			} else if (regressor instanceof CubicSpline) {
				// The array returned here is written in place
				double[] beta = ((CubicSpline) regressor).getCoefficients();
				for (int i = 0; i < beta.length; i++) {
					beta[i] = coef[k++];
				}
			}
		}
		return gam;
	}

}

================================================
FILE: src/main/java/mltk/predictor/gam/ScorecardModelLearner.java
================================================
package mltk.predictor.gam;

import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.cmdline.options.LearnerWithTaskOptions;
import mltk.core.Instances;
import mltk.core.io.InstancesReader;
import mltk.core.processor.OneHotEncoder;
import mltk.predictor.Learner;
import mltk.predictor.glm.GLM;
import mltk.predictor.glm.RidgeLearner;
import mltk.predictor.io.PredictorWriter;

/**
 * Class for learning scorecard models. Scorecard models are a special kind of
 * generalized additive models where scores for each state of the feature are
 * learned.
 *
 * <p>
 * The training set is one-hot encoded, a ridge-regularized GLM is fitted on the encoded data, and the GLM is
 * converted back into a GAM over the original attributes.
 * </p>
 *
 * @author Yin Lou
 *
 */
public class ScorecardModelLearner extends Learner {

	static class Options extends LearnerWithTaskOptions {

		@Argument(name = "-m", description = "maximum number of iterations", required = true)
		int maxNumIters = -1;

		@Argument(name = "-l", description = "lambda (default: 0)")
		double lambda = 0;

	}

	/**
	 * Trains a scorecard model.
	 *
	 * <pre>
	 * Usage: mltk.predictor.gam.ScorecardModelLearner
	 * -t	train set path
	 * -m	maximum number of iterations
	 * [-g]	task between classification (c) and regression (r) (default: r)
	 * [-r]	attribute file path
	 * [-o]	output model path
	 * [-V]	verbose (default: true)
	 * [-l]	lambda (default: 0)
	 * </pre>
	 *
	 * @param args the command line arguments.
	 * @throws Exception if reading the data or writing the model fails.
	 */
	public static void main(String[] args) throws Exception {
		Options opts = new Options();
		CmdLineParser parser = new CmdLineParser(ScorecardModelLearner.class, opts);
		Task task = null;
		try {
			parser.parse(args);
			task = Task.get(opts.task);
		} catch (IllegalArgumentException e) {
			parser.printUsage();
			System.exit(1);
		}

		Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath);

		ScorecardModelLearner learner = new ScorecardModelLearner();
		learner.setMaxNumIters(opts.maxNumIters);
		learner.setLambda(opts.lambda);
		learner.setTask(task);
		learner.setVerbose(opts.verbose);

		long start = System.currentTimeMillis();
		GAM gam = learner.build(trainSet);
		long end = System.currentTimeMillis();
		System.out.println("Time: " + (end - start) / 1000.0);

		if (opts.outputModelPath != null) {
			PredictorWriter.write(gam, opts.outputModelPath);
		}
	}

	private int maxNumIters;
	private double lambda;
	private Task task;
	private OneHotEncoder encoder;

	/**
	 * Constructor. The learner starts non-verbose, with lambda 0, the regression task, and a negative
	 * maximum number of iterations, which {@link #build(Instances)} replaces by a data-dependent value.
	 */
	public ScorecardModelLearner() {
		verbose = false;
		maxNumIters = -1;
		encoder = new OneHotEncoder();
		lambda = 0;
		task = Task.REGRESSION;
	}

	/**
	 * Returns the lambda.
	 *
	 * @return the lambda.
	 */
	public double getLambda() {
		return lambda;
	}

	/**
	 * Sets the lambda.
	 *
	 * @param lambda the lambda.
	 */
	public void setLambda(double lambda) {
		this.lambda = lambda;
	}

	/**
	 * Returns the maximum number of iterations.
	 *
	 * @return the maximum number of iterations.
	 */
	public int getMaxNumIters() {
		return maxNumIters;
	}

	/**
	 * Sets the maximum number of iterations. A negative value lets {@link #build(Instances)} choose
	 * 20 times the dimension of the training set.
	 *
	 * @param maxNumIters the maximum number of iterations.
	 */
	public void setMaxNumIters(int maxNumIters) {
		this.maxNumIters = maxNumIters;
	}

	/**
	 * Returns the task of this learner.
	 *
	 * @return the task of this learner.
	 */
	public Task getTask() {
		return task;
	}

	/**
	 * Sets the task of this learner.
	 *
	 * @param task the task of this learner.
	 */
	public void setTask(Task task) {
		this.task = task;
	}

	/**
	 * Builds a classifier.
	 *
	 * @param trainSet the training set.
	 * @param maxNumIters the maximum number of iterations.
	 * @param lambda the L2 regularization parameter.
	 * @return a classifier.
	 */
	public GAM buildClassifier(Instances trainSet, int maxNumIters, double lambda) {
		return fitScorecard(trainSet, Task.CLASSIFICATION, maxNumIters, lambda);
	}

	/**
	 * Builds a regressor.
	 *
	 * @param trainSet the training set.
	 * @param maxNumIters the maximum number of iterations.
	 * @param lambda the L2 regularization parameter.
	 * @return a regressor.
	 */
	public GAM buildRegressor(Instances trainSet, int maxNumIters, double lambda) {
		return fitScorecard(trainSet, Task.REGRESSION, maxNumIters, lambda);
	}

	/**
	 * Builds a model for the configured task.
	 *
	 * <p>
	 * When the configured maximum number of iterations is negative, 20 times the dimension of
	 * {@code instances} is used for this call only; the configured value is left untouched so that a reused
	 * learner derives a fresh limit for every training set.
	 * </p>
	 *
	 * @param instances the training set.
	 * @return the fitted GAM, or {@code null} when the task is neither regression nor classification.
	 */
	@Override
	public GAM build(Instances instances) {
		final int numIters = maxNumIters < 0 ? instances.dimension() * 20 : maxNumIters;

		GAM gam = null;
		switch (task) {
			case REGRESSION:
				gam = buildRegressor(instances, numIters, lambda);
				break;
			case CLASSIFICATION:
				gam = buildClassifier(instances, numIters, lambda);
				break;
			default:
				break;
		}
		return gam;
	}

	/** One-hot encodes the training set, fits a ridge GLM for the given task and converts it to a GAM. */
	private GAM fitScorecard(Instances trainSet, Task fitTask, int maxNumIters, double lambda) {
		Instances trainSetNew = encoder.process(trainSet);

		RidgeLearner learner = new RidgeLearner();
		learner.setTask(fitTask);
		learner.setLambda(lambda);
		learner.setVerbose(verbose);
		learner.setMaxNumIters(maxNumIters);
		GLM glm = learner.build(trainSetNew);

		// The GAM is expressed over the attributes of the original, un-encoded training set
		return GAMUtils.getGAM(glm, trainSet.getAttributes());
	}

}

================================================
FILE: src/main/java/mltk/predictor/gam/SparseDesignMatrix.java
================================================
package mltk.predictor.gam;

import java.util.HashSet;
import java.util.Set;

import mltk.predictor.function.CubicSpline;
import mltk.util.StatUtils;
import mltk.util.VectorUtils;

/**
 * Holder for a sparse, block-structured design matrix: per attribute the row indices, the basis columns,
 * the knots and the scaling factor of every column.
 */
class SparseDesignMatrix {

	int[][] indices;
	double[][][] values;
	double[][] knots;
	double[][] std;

	/**
	 * Constructor. The arrays are stored as given, without copying.
	 *
	 * @param indices the row indices, one array per attribute.
	 * @param values the basis columns, indexed by attribute, then column, then entry.
	 * @param knots the knots, one array per attribute.
	 * @param std the scaling factors, indexed by attribute, then column.
	 */
	SparseDesignMatrix(int[][] indices, double[][][] values, double[][] knots, double[][] std) {
		this.indices = indices;
		this.values = values;
		this.knots = knots;
		this.std = std;
	}

	/**
	 * Creates a design matrix with a cubic spline basis for every attribute.
	 *
	 * <p>
	 * An attribute with at most {@code numKnots - 1} distinct stored values gets a single column holding the
	 * raw values. Every other attribute gets {@code numKnots + 3} columns: the raw values, their squares,
	 * their cubes, and one column per knot holding {@code CubicSpline.h(x, knot) - CubicSpline.h(0, knot)}.
	 * Knots are spaced evenly from the minimum towards the maximum of the stored values, where the range is
	 * widened to include 0. Finally every column is divided by its scaling factor, which is the column's
	 * standard deviation divided by {@code sqrt(n)}.
	 * </p>
	 *
	 * @param n the number of instances.
	 * @param indices the row indices of the stored entries, one array per attribute.
	 * @param values the stored values, one array per attribute; reused as the first basis column and scaled
	 *            in place.
	 * @param stdList the standard deviation of the raw values of each attribute.
	 * @param numKnots the number of knots for attributes that get a spline basis.
	 * @return the design matrix.
	 */
	static SparseDesignMatrix createCubicSplineDesignMatrix(int n, int[][] indices, double[][] values,
			double[] stdList, int numKnots) {
		final int p = indices.length;
		double[][][] x = new double[p][][];
		double[][] knots = new double[p][];
		double[][] std = new double[p][];
		double factor = Math.sqrt(n);

		// NOTE(review): this loop runs over values.length while the arrays are sized by indices.length;
		// presumably both are equal - confirm against callers.
		for (int j = 0; j < values.length; j++) {
			Set<Double> uniqueValues = new HashSet<>();
			double[] x1 = values[j];
			for (int i = 0; i < values[j].length; i++) {
				uniqueValues.add(x1[i]);
			}
			// Too few distinct values: fall back to a purely linear column
			int nKnots = uniqueValues.size() + 1 <= numKnots ? 0 : numKnots;
			knots[j] = new double[nKnots];

			if (nKnots != 0) {
				x[j] = new double[nKnots + 3][];
				std[j] = new double[nKnots + 3];
			} else {
				x[j] = new double[1][];
				std[j] = new double[1];
			}

			double[][] tX = x[j];
			tX[0] = x1;
			std[j][0] = stdList[j] / factor;

			if (nKnots != 0) {
				double[] x2 = new double[x1.length];
				for (int i = 0; i < x2.length; i++) {
					x2[i] = x1[i] * x1[i];
				}
				tX[1] = x2;

				double[] x3 = new double[x1.length];
				for (int i = 0; i < x3.length; i++) {
					x3[i] = x2[i] * x1[i];
				}
				tX[2] = x3;

				std[j][1] = StatUtils.sd(x2, n) / factor;
				std[j][2] = StatUtils.sd(x3, n) / factor;

				// 0 is forced into the range, presumably to account for entries that are not stored
				double max = Math.max(StatUtils.max(x1), 0);
				double min = Math.min(StatUtils.min(x1), 0);
				double stepSize = (max - min) / nKnots;
				for (int k = 0; k < nKnots; k++) {
					knots[j][k] = min + stepSize * k;
					double[] basis = new double[x1.length];
					// Shift the basis so that it evaluates to 0 at x = 0
					double zero = CubicSpline.h(0, knots[j][k]);
					for (int i = 0; i < basis.length; i++) {
						basis[i] = CubicSpline.h(x1[i], knots[j][k]) - zero;
					}
					std[j][k + 3] = StatUtils.sd(basis, n) / factor;
					tX[k + 3] = basis;
				}
			}
		}

		// Normalize the inputs
		// NOTE(review): a column with zero scaling factor is divided by 0 here - confirm callers exclude it.
		for (int j = 0; j < p; j++) {
			double[][] block = x[j];
			double[] s = std[j];
			for (int i = 0; i < block.length; i++) {
				VectorUtils.divide(block[i], s[i]);
			}
		}

		return new SparseDesignMatrix(indices, x, knots, std);
	}

}

================================================
FILE: 
src/main/java/mltk/predictor/gam/interaction/FAST.java ================================================ package mltk.predictor.gam.interaction; import java.io.BufferedReader; import java.io.FileReader; import java.io.PrintWriter; import java.util.ArrayList; import java.util.Collections; import java.util.List; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.core.Attribute; import mltk.core.BinnedAttribute; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.core.Attribute.Type; import mltk.core.io.InstancesReader; import mltk.core.processor.Discretizer; import mltk.predictor.function.CHistogram; import mltk.predictor.function.Histogram2D; import mltk.util.Element; import mltk.util.MathUtils; import mltk.util.tuple.IntPair; /** * Class for fast interaction detection. * *

* Reference:
* Y. Lou, R. Caruana, J. Gehrke, and G. Hooker. Accurate intelligible models with pairwise interactions. In * Proceedings of the 19th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), * Chicago, IL, USA, 2013. *

* * @author Yin Lou * */ public class FAST { static class FASTThread extends Thread { List> pairs; Instances instances; FASTThread(Instances instances) { this.instances = instances; this.pairs = new ArrayList<>(); } public void add(Element pair) { pairs.add(pair); } public void run() { FAST.computeWeights(instances, pairs); } } static class Options { @Argument(name = "-r", description = "attribute file path") String attPath = null; @Argument(name = "-d", description = "dataset path", required = true) String datasetPath = null; @Argument(name = "-R", description = "residual path", required = true) String residualPath = null; @Argument(name = "-o", description = "output path", required = true) String outputPath = null; @Argument(name = "-b", description = "number of bins (default: 256)") int maxNumBins = 256; @Argument(name = "-p", description = "number of threads (default: 1)") int numThreads = 1; } /** * Ranks pairwise interactions using FAST. * *
	 * Usage: mltk.predictor.gam.interaction.FAST
	 * -d	dataset path
	 * -R	residual path
	 * -o	output path
	 * [-r]	attribute file path
	 * [-b]	number of bins (default: 256)
	 * [-p]	number of threads (default: 1)
	 * 
* * @param args the command line arguments * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(FAST.class, opts); try { parser.parse(args); } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Instances instances = InstancesReader.read(opts.attPath, opts.datasetPath); System.out.println("Reading residuals..."); BufferedReader br = new BufferedReader(new FileReader(opts.residualPath), 65535); for (int i = 0; i < instances.size(); i++) { String line = br.readLine(); double residual = Double.parseDouble(line); Instance instance = instances.get(i); instance.setTarget(residual); } br.close(); List attributes = instances.getAttributes(); System.out.println("Discretizing attribute..."); for (int i = 0; i < attributes.size(); i++) { if (attributes.get(i).getType() == Type.NUMERIC) { Discretizer.discretize(instances, i, opts.maxNumBins); } } System.out.println("Generating all pairs of attributes..."); List> pairs = new ArrayList<>(); for (int i = 0; i < attributes.size(); i++) { for (int j = i + 1; j < attributes.size(); j++) { pairs.add(new Element(new IntPair(i, j), 0.0)); } } System.out.println("Creating threads..."); FASTThread[] threads = new FASTThread[opts.numThreads]; long start = System.currentTimeMillis(); for (int i = 0; i < threads.length; i++) { threads[i] = new FASTThread(instances); } for (int i = 0; i < pairs.size(); i++) { threads[i % threads.length].add(pairs.get(i)); } for (int i = 0; i < threads.length; i++) { threads[i].start(); } System.out.println("Running FAST..."); for (int i = 0; i < threads.length; i++) { threads[i].join(); } long end = System.currentTimeMillis(); System.out.println("Sorting pairs..."); Collections.sort(pairs); System.out.println("Time: " + (end - start) / 1000.0); PrintWriter out = new PrintWriter(opts.outputPath); for (int i = 0; i < pairs.size(); i++) { Element pair = pairs.get(i); 
out.println(pair.element.v1 + "\t" + pair.element.v2 + "\t" + pair.weight); } out.flush(); out.close(); } /** * Computes the weights of pairwise interactions. * * @param instances the training set. * @param pairs the list of pairs to compute. */ public static void computeWeights(Instances instances, List> pairs) { List attributes = instances.getAttributes(); boolean[] used = new boolean[attributes.size()]; for (Element pair : pairs) { int f1 = pair.element.v1; int f2 = pair.element.v2; used[f1] = used[f2] = true; } CHistogram[] cHist = new CHistogram[attributes.size()]; for (int i = 0; i < cHist.length; i++) { if (used[i]) { switch (attributes.get(i).getType()) { case BINNED: BinnedAttribute binnedAtt = (BinnedAttribute) attributes.get(i); cHist[i] = new CHistogram(binnedAtt.getNumBins()); break; case NOMINAL: NominalAttribute nominalAtt = (NominalAttribute) attributes.get(i); cHist[i] = new CHistogram(nominalAtt.getCardinality()); default: break; } } } double ySq = computeCHistograms(instances, used, cHist); for (Element pair : pairs) { final int f1 = pair.element.v1; final int f2 = pair.element.v2; final int size1 = cHist[f1].size(); final int size2 = cHist[f2].size(); Histogram2D hist2d = new Histogram2D(size1, size2); Histogram2D.computeHistogram2D(instances, f1, f2, hist2d); computeWeight(pair, cHist, hist2d, ySq); } } protected static double computeCHistograms(Instances instances, boolean[] used, CHistogram[] cHist) { double ySq = 0; // compute histogram for (Instance instance : instances) { double resp = instance.getTarget(); for (int j = 0; j < instances.getAttributes().size(); j++) { if (used[j]) { if (!instance.isMissing(j)) { int idx = (int) instance.getValue(j); cHist[j].sum[idx] += resp * instance.getWeight(); cHist[j].count[idx] += instance.getWeight(); } else { cHist[j].sumOnMV += resp * instance.getWeight(); cHist[j].countOnMV += instance.getWeight(); } } } ySq += resp * resp * instance.getWeight(); } // compute cumulative histogram for (int j = 0; 
j < cHist.length; j++) { if (used[j]) { for (int idx = 1; idx < cHist[j].size(); idx++) { cHist[j].sum[idx] += cHist[j].sum[idx - 1]; cHist[j].count[idx] += cHist[j].count[idx - 1]; } } } return ySq; } protected static void computeWeight(Element pair, CHistogram[] cHist, Histogram2D hist2d, double ySq) { final int f1 = pair.element.v1; final int f2 = pair.element.v2; final int size1 = cHist[f1].size(); final int size2 = cHist[f2].size(); Histogram2D.Table table = Histogram2D.computeTable(hist2d, cHist[f1], cHist[f2]); double bestRSS = Double.POSITIVE_INFINITY; double[] predInt = new double[4]; double[] predOnMV1 = new double[2]; double[] predOnMV2 = new double[2]; double predOnMV12 = MathUtils.divide(hist2d.respOnMV12, hist2d.countOnMV12, 0); for (int v1 = 0; v1 < size1 - 1; v1++) { for (int v2 = 0; v2 < size2 - 1; v2++) { getPredictor(table, v1, v2, predInt, predOnMV1, predOnMV2); double rss = getRSS(table, v1, v2, ySq, predInt, predOnMV1, predOnMV2, predOnMV12); if (rss < bestRSS) { bestRSS = rss; } } } pair.weight = bestRSS; } protected static void getPredictor(Histogram2D.Table table, int v1, int v2, double[] pred, double[] predOnMV1, double[] predOnMV2) { double[] count = table.count[v1][v2]; double[] resp = table.resp[v1][v2]; for (int i = 0; i < pred.length; i++) { pred[i] = MathUtils.divide(resp[i], count[i], 0); } for (int i = 0; i < predOnMV1.length; i++) { predOnMV1[i] = MathUtils.divide(table.respOnMV1[v2][i], table.countOnMV1[v2][i], 0); } for (int i = 0; i < predOnMV2.length; i++) { predOnMV2[i] = MathUtils. 
divide(table.respOnMV2[v1][i], table.countOnMV2[v1][i], 0); } } protected static double getRSS(Histogram2D.Table table, int v1, int v2, double ySq, double[] pred, double[] predOnMV1, double[] predOnMV2, double predOnMV12) { double[] count = table.count[v1][v2]; double[] resp = table.resp[v1][v2]; double[] respOnMV1 = table.respOnMV1[v2]; double[] countOnMV1 = table.countOnMV1[v2]; double[] respOnMV2 = table.respOnMV2[v1]; double[] countOnMV2 = table.countOnMV2[v1]; double rss = ySq; // Compute main area double t = 0; for (int i = 0; i < pred.length; i++) { t += pred[i] * pred[i] * count[i]; } rss += t; t = 0; for (int i = 0; i < pred.length; i++) { t += pred[i] * resp[i]; } rss -= 2 * t; // Compute on mv1 t = 0; for (int i = 0; i < predOnMV1.length; i++) { t += predOnMV1[i] * predOnMV1[i] * countOnMV1[i]; } rss += t; t = 0; for (int i = 0; i < predOnMV1.length; i++) { t += predOnMV1[i] * respOnMV1[i]; } rss -= 2 * t; // Compute on mv2 t = 0; for (int i = 0; i < predOnMV2.length; i++) { t += predOnMV2[i] * predOnMV2[i] * countOnMV2[i]; } rss += t; t = 0; for (int i = 0; i < predOnMV2.length; i++) { t += predOnMV2[i] * respOnMV2[i]; } rss -= 2 * t; return rss; } } ================================================ FILE: src/main/java/mltk/predictor/gam/interaction/package-info.java ================================================ /** * Provides algorithms for feature interaction detection. */ package mltk.predictor.gam.interaction; ================================================ FILE: src/main/java/mltk/predictor/gam/package-info.java ================================================ /** * Provides algorithms for fitting generalized additive models (GAMs). 
*/ package mltk.predictor.gam; ================================================ FILE: src/main/java/mltk/predictor/gam/tool/Diagnostics.java ================================================ package mltk.predictor.gam.tool; import java.io.PrintWriter; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.io.InstancesReader; import mltk.predictor.Regressor; import mltk.predictor.gam.GAM; import mltk.predictor.io.PredictorReader; import mltk.util.StatUtils; import mltk.util.Element; /** * Class for GAM diagnostics. * * @author Yin Lou * */ public class Diagnostics { /** * Enumeration of methods for calculating term importance. * * @author Yin Lou * */ public enum Mode { /** * L1. */ L1("L1"), /** * PDF terminal. */ L2("L2"); String mode; Mode(String mode) { this.mode = mode; } public String toString() { return mode; } /** * Parses a mode from a string. * * @param mode the mode. * @return a parsed terminal. */ public static Mode getEnum(String mode) { for (Mode re : Mode.values()) { if (re.mode.compareTo(mode) == 0) { return re; } } throw new IllegalArgumentException("Invalid mode: " + mode); } } /** * Computes the weights for each term in a GAM. * * @param gam the GAM model. * @param instances the training set. * @return the list of weights for each term in a GAM. */ public static List> diagnose(GAM gam, Instances instances) { return diagnose(gam, instances, Mode.L2); } /** * Computes the weights for each term in a GAM. * * @param gam the GAM model. * @param instances the training set. * @return the list of weights for each term in a GAM. 
*/ public static List> diagnose(GAM gam, Instances instances, Mode mode) { List> list = new ArrayList<>(); Map> map = new HashMap<>(); List terms = gam.getTerms(); List regressors = gam.getRegressors(); for (int i = 0; i < terms.size(); i++) { int[] term = terms.get(i); if (!map.containsKey(term)) { map.put(term, new ArrayList()); } Regressor regressor = regressors.get(i); map.get(term).add(regressor); } double[] predictions = new double[instances.size()]; for (int[] term : map.keySet()) { List regressorList = map.get(term); for (int i = 0; i < instances.size(); i++) { predictions[i] = 0; Instance instance = instances.get(i); for (Regressor regressor : regressorList) { predictions[i] += regressor.regress(instance); } } double weight = 0; if (mode == Mode.L2) { weight = StatUtils.variance(predictions); } else { double mean = StatUtils.mean(predictions); weight = StatUtils.mad(predictions, mean); } list.add(new Element(term, weight)); } return list; } static class Options { @Argument(name = "-r", description = "attribute file path") String attPath = null; @Argument(name = "-d", description = "dataset path", required = true) String datasetPath = null; @Argument(name = "-m", description = "mode (L1 or L2, default: L2)") String mode = null; @Argument(name = "-i", description = "input model path", required = true) String inputModelPath = null; @Argument(name = "-o", description = "output path", required = true) String outputPath = null; } /** * Generates term importance for GAMs. * *
	 * Usage: mltk.predictor.gam.tool.Diagnostics
	 * -d	dataset path
	 * -i	input model path
	 * -o	output path
	 * [-r]	attribute file path
	 * [-m]	mode (L1 or L2, default: L2)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(Diagnostics.class, opts); try { parser.parse(args); } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Instances dataset = InstancesReader.read(opts.attPath, opts.datasetPath); GAM gam = PredictorReader.read(opts.inputModelPath, GAM.class); List> list = Diagnostics.diagnose(gam, dataset, Mode.getEnum(opts.mode)); Collections.sort(list); Collections.reverse(list); PrintWriter out = new PrintWriter(opts.outputPath); for (Element element : list) { int[] term = element.element; double weight = element.weight; out.println(Arrays.toString(term) + ": " + weight); } out.flush(); out.close(); } } ================================================ FILE: src/main/java/mltk/predictor/gam/tool/Visualizer.java ================================================ package mltk.predictor.gam.tool; import java.io.File; import java.io.IOException; import java.io.PrintWriter; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.core.Attribute; import mltk.core.BinnedAttribute; import mltk.core.Bins; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.core.io.InstancesReader; import mltk.predictor.Regressor; import mltk.predictor.function.Array1D; import mltk.predictor.function.Array2D; import mltk.predictor.function.CubicSpline; import mltk.predictor.function.Function1D; import mltk.predictor.gam.GAM; import mltk.predictor.io.PredictorReader; import mltk.util.MathUtils; /** * Class for visualizing 1D and 2D components in a GAM. * * @author Yin Lou * */ public class Visualizer { /** * Enumeration of output terminals. 
* * @author Yin Lou * */ public enum Terminal { /** * PNG terminal. */ PNG("png"), /** * PDF terminal. */ PDF("pdf"); String term; Terminal(String term) { this.term = term; } public String toString() { return term; } /** * Parses an enumeration from a string. * * @param term the string. * @return a parsed terminal. */ public static Terminal getEnum(String term) { for (Terminal re : Terminal.values()) { if (re.term.compareTo(term) == 0) { return re; } } throw new IllegalArgumentException("Invalid Terminal value: " + term); } } /** * Generates a set of Gnuplot scripts for visualizing low dimensional components in a GAM. * * @param gam the GAM model. * @param instances the training set. * @param dirPath the directory path to write to. * @param outputTerminal output plot format (png or pdf). * @throws IOException */ public static void generateGnuplotScripts(GAM gam, Instances instances, String dirPath, Terminal outputTerminal) throws IOException { List attributes = instances.getAttributes(); int p = -1; Map attMap = new HashMap<>(attributes.size()); for (Attribute attribute : attributes) { int attIndex = attribute.getIndex(); attMap.put(attIndex, attribute); if (attIndex > p) { p = attIndex; } } p++; List terms = gam.getTerms(); List regressors = gam.getRegressors(); File dir = new File(dirPath); if (!dir.exists()) { dir.mkdirs(); } double[] value = new double[p]; Instance point = new Instance(value); String terminal = outputTerminal.toString(); for (int i = 0; i < terms.size(); i++) { int[] term = terms.get(i); Regressor regressor = regressors.get(i); if (term.length == 1) { Attribute f = attMap.get(term[0]); switch (f.getType()) { case BINNED: int numBins = ((BinnedAttribute) f).getNumBins(); if (numBins == 1) { continue; } break; case NOMINAL: int numStates = ((NominalAttribute) f).getStates().length; if (numStates == 1) { continue; } break; default: break; } Double predictionOnMV = null; if (regressor instanceof Function1D) { double predOnMV = ((Function1D) 
regressor).getPredictionOnMV(); if (!MathUtils.isZero(predOnMV)) { predictionOnMV = predOnMV; } } else if (regressor instanceof Array1D) { double predOnMV = ((Array1D) regressor).getPredictionOnMV(); if (!MathUtils.isZero(predOnMV)) { predictionOnMV = predOnMV; } } PrintWriter out = new PrintWriter(dir.getAbsolutePath() + File.separator + f.getName() + ".plt"); out.printf("set term %s\n", terminal); out.printf("set output \"%s.%s\"\n", f.getName(), terminal); out.println("set datafile separator \"\t\""); out.println("set grid"); if (predictionOnMV != null) { out.println("set multiplot layout 1,2 rowsfirst"); } // Plot main function switch (f.getType()) { case BINNED: int numBins = ((BinnedAttribute) f).getNumBins(); Bins bins = ((BinnedAttribute) f).getBins(); double[] boundaries = bins.getBoundaries(); double start = boundaries[0] - 1; if (boundaries.length >= 2) { start = boundaries[0] - (boundaries[1] - boundaries[0]); } out.printf("set xrange[%f:%f]\n", start, boundaries[boundaries.length - 1]); List predList = new ArrayList<>(); for (int j = 0; j < numBins; j++) { point.setValue(term[0], j); predList.add(regressor.regress(point)); } point.setValue(term[0], 0); {// Writing plot data to file String fileName = f.getName() + ".dat"; out.println("plot \"" + fileName + "\" u 1:2 w l t \"\""); PrintWriter writer = new PrintWriter(dir.getAbsolutePath() + File.separator + fileName); writer.printf("%f\t%f\n", start, predList.get(0)); for (int j = 0; j < numBins; j++) { point.setValue(term[0], j); writer.printf("%f\t%f\n", boundaries[j], predList.get(j)); if (j < numBins - 1) { writer.printf("%f\t%f\n", boundaries[j], predList.get(j + 1)); } } writer.flush(); writer.close(); } break; case NOMINAL: out.println("set style data histogram"); out.println("set style histogram cluster gap 1"); out.println("set style fill solid border -1"); out.println("set boxwidth 0.9"); out.println("set xtic rotate by -90"); String[] states = ((NominalAttribute) f).getStates(); {// Writing 
plot data to file String fileName = f.getName() + ".dat"; out.println("plot \"" + fileName + "\" u 2:xtic(1) t \"\""); PrintWriter writer = new PrintWriter(dir.getAbsolutePath() + File.separator + fileName); for (int j = 0; j < states.length; j++) { point.setValue(term[0], j); writer.printf("%s\t%f\n", states[j], regressor.regress(point)); } writer.flush(); writer.close(); } break; default: Set values = new HashSet<>(); for (Instance instance : instances) { values.add(instance.getValue(term[0])); } List list = new ArrayList<>(values); Collections.sort(list); out.printf("set xrange[%f:%f]\n", list.get(0), list.get(list.size() - 1)); if (regressor instanceof CubicSpline) { CubicSpline spline = (CubicSpline) regressor; out.println("z(x) = x < 0 ? 0 : x ** 3"); out.println("h(x, k) = z(x - k)"); double[] knots = spline.getKnots(); double[] w = spline.getCoefficients(); StringBuilder sb = new StringBuilder(); sb.append("plot ").append(spline.getIntercept()); sb.append(" + ").append(w[0]).append(" * x"); sb.append(" + ").append(w[1]).append(" * (x ** 2)"); sb.append(" + ").append(w[2]).append(" * (x ** 3)"); for (int j = 0; j < knots.length; j++) { sb.append(" + ").append(w[j + 3]).append(" * "); sb.append("h(x, ").append(knots[j]).append(")"); } sb.append(" t \"\""); out.println(sb.toString()); } else { out.println("plot \"-\" u 1:2 w lp t \"\""); for (double v : list) { point.setValue(term[0], v); out.printf("%f\t%f\n", v, regressor.regress(point)); } out.println("e"); } break; } // Plot prediction on missing value if (predictionOnMV != null) { out.println("set style fill solid border -1"); out.println("set xtic rotate by 0"); out.println("plot \"-\" using 2:xtic(1) with histogram t \"\""); out.println("missing value\t" + predictionOnMV); out.println("e"); } out.flush(); out.close(); } else if (term.length == 2) { Attribute f1 = attMap.get(term[0]); Attribute f2 = attMap.get(term[1]); String fileName = f1.getName() + "_" + f2.getName(); PrintWriter out = new 
PrintWriter(dir.getAbsolutePath() + File.separator + fileName + ".plt"); out.printf("set term %s\n", terminal); out.printf("set output \"%s_%s.%s\"\n", f1.getName(), f2.getName(), terminal); out.println("set datafile separator \"\t\""); int numRow = 1; int numCol = 1; double[] predictionsOnMV1 = null; double[] predictionsOnMV2 = null; double predictionsOnMV12 = 0.0; if (regressor instanceof Array2D) { Array2D ary2d = (Array2D) regressor; for (double v : ary2d.getPredictionsOnMV1()) { if (!MathUtils.isZero(v)) { numRow = 2; predictionsOnMV1 = ary2d.getPredictionsOnMV1(); break; } } for (double v : ary2d.getPredictionsOnMV2()) { if (!MathUtils.isZero(v)) { numCol = 2; predictionsOnMV2 = ary2d.getPredictionsOnMV2(); break; } } if (!MathUtils.isZero(ary2d.getPredictionOnMV12())) { numRow = 2; numCol = 2; predictionsOnMV12 = ary2d.getPredictionOnMV12(); } } if (numRow > 1 || numCol > 1) { out.printf("set multiplot layout %d,%d rowsfirst\n", numRow, numCol); } int size1 = 0; if (f1.getType() == Attribute.Type.BINNED) { size1 = ((BinnedAttribute) f1).getNumBins(); } else if (f1.getType() == Attribute.Type.NOMINAL) { size1 = ((NominalAttribute) f1).getCardinality(); } int size2 = 0; if (f2.getType() == Attribute.Type.BINNED) { size2 = ((BinnedAttribute) f2).getNumBins(); } else if (f2.getType() == Attribute.Type.NOMINAL) { size2 = ((NominalAttribute) f2).getCardinality(); } if (f1.getType() == Attribute.Type.NOMINAL) { out.print("set ytics("); String[] states = ((NominalAttribute) f1).getStates(); for (int j = 0; j < states.length - 1; j++) { out.printf("\"%s\" %d, ", states[j], j); } out.printf("\"%s\" %d)\n", states[states.length - 1], states.length - 1); } if (f2.getType() == Attribute.Type.NOMINAL) { out.print("set xtics("); String[] states = ((NominalAttribute) f2).getStates(); for (int j = 0; j < states.length - 1; j++) { out.printf("\"%s\" %d, ", states[j], j); } out.printf("\"%s\" %d) rotate\n", states[states.length - 1], states.length - 1); } out.println("unset 
border"); out.println("set view map"); out.println("set style data pm3d"); out.println("set style function pm3d"); out.println("set pm3d corners2color c4"); double[] rangeY = null; double[] valueY = null; double[] rangeX = null; double[] valueX = null; if (f1 instanceof BinnedAttribute) { rangeY = new double[size1 + 1]; valueY = new double[size1 + 1]; Bins bins = ((BinnedAttribute) f1).getBins(); double[] boundaries = bins.getBoundaries(); double start = boundaries[0] - 1; if (boundaries.length >= 2) { start = boundaries[0] - (boundaries[1] - boundaries[0]); } rangeY[0] = start; valueY[0] = 0; for (int k = 0; k < boundaries.length; k++) { rangeY[k + 1] = boundaries[k]; valueY[k + 1] = k; } } else if (f1 instanceof NominalAttribute) { rangeY = new double[size1]; valueY = new double[size1]; for (int k = 0; k < size1; k++) { rangeY[k] = k; valueY[k] = k; } } out.printf("set yrange[%f:%f]\n", rangeY[0] - 1, rangeY[rangeY.length - 1] + 1); if (f2 instanceof BinnedAttribute) { rangeX = new double[size2 + 1]; valueX = new double[size2 + 1]; Bins bins = ((BinnedAttribute) f2).getBins(); double[] boundaries = bins.getBoundaries(); double start = boundaries[0] - 1; if (boundaries.length >= 2) { start = boundaries[0] - (boundaries[1] - boundaries[0]); } rangeX[0] = start; valueX[0] = 0; for (int k = 0; k < boundaries.length; k++) { rangeX[k + 1] = boundaries[k]; valueX[k + 1] = k; } } else if (f2 instanceof NominalAttribute) { rangeX = new double[size2]; valueX = new double[size2]; for (int k = 0; k < size2; k++) { rangeX[k] = k; valueX[k] = k; } } out.printf("set xrange[%f:%f]\n", rangeX[0] - 1, rangeX[rangeX.length - 1] + 1); out.println("splot \"" + fileName + ".dat\" with image t \"\""); PrintWriter writer = new PrintWriter(dir.getAbsolutePath() + File.separator + fileName + ".dat"); for (int r = 0; r < rangeY.length; r++) { point.setValue(term[0], valueY[r]); for (int c = 0; c < rangeX.length; c++) { point.setValue(term[1], valueX[c]); writer.println(rangeX[c] + "\t" + 
rangeY[r] + "\t" + gam.regress(point)); } writer.println(); } writer.flush(); writer.close(); if (numRow > 1 || numCol > 1) { // multiplot is on if (predictionsOnMV2 != null) { out.println("reset"); writer = new PrintWriter(dir.getAbsolutePath() + File.separator + fileName + "_mv2.dat"); if (f1.getType() == Attribute.Type.NOMINAL) { String[] states = ((NominalAttribute) f1).getStates(); out.println("set style data histogram"); out.println("set style histogram cluster gap 1"); out.println("set style fill solid border -1"); out.println("set boxwidth 0.9"); out.println("set xtic rotate by -90"); out.println("plot \"" + fileName + "_mv2.dat\" u 2:xtic(1) t \"\""); for (int k = 0; k < states.length; k++) { writer.println(states[k] + "\t" + predictionsOnMV2[k]); } } else if (f1.getType() == Attribute.Type.BINNED) { out.println("plot \"" + fileName + "_mv2.dat\" u 1:2 w l t \"\""); writer.println(rangeY[0] + "\t" + predictionsOnMV2[0]); for (int k = 0; k < predictionsOnMV2.length; k++) { writer.println(rangeY[k + 1] + "\t" + predictionsOnMV2[k]); } } writer.flush(); writer.close(); } else if (numCol > 1) { out.println("set multiplot next"); } if (predictionsOnMV1 != null) { out.println("reset"); writer = new PrintWriter(dir.getAbsolutePath() + File.separator + fileName + "_mv1.dat"); if (f2.getType() == Attribute.Type.NOMINAL) { String[] states = ((NominalAttribute) f2).getStates(); out.println("set style data histogram"); out.println("set style histogram cluster gap 1"); out.println("set style fill solid border -1"); out.println("set boxwidth 0.9"); out.println("set xtic rotate by -90"); out.println("plot \"" + fileName + "_mv1.dat\" u 2:xtic(1) t \"\""); for (int k = 0; k < states.length; k++) { writer.println(states[k] + "\t" + predictionsOnMV1[k]); } } else if (f2.getType() == Attribute.Type.BINNED) { out.println("plot \"" + fileName + "_mv1.dat\" u 1:2 w l t \"\""); writer.println(rangeX[0] + "\t" + predictionsOnMV1[0]); for (int k = 0; k < predictionsOnMV1.length; 
k++) { writer.println(rangeX[k + 1] + "\t" + predictionsOnMV1[k]); } } writer.flush(); writer.close(); } else if (numRow > 1) { out.println("set multiplot next"); } if (!MathUtils.isZero(predictionsOnMV12)) { out.println("reset"); out.println("set datafile separator \"\t\""); out.println("set style fill solid border -1"); out.println("set xtic rotate by 0"); out.println("plot \"-\" using 2:xtic(1) with histogram t \"\""); out.println("missing value\t" + predictionsOnMV12); out.println("e"); } else if (numRow > 1 && numCol > 1) { out.println("set multiplot next"); } } out.flush(); out.close(); } } } static class Options { @Argument(name = "-r", description = "attribute file path", required = true) String attPath = null; @Argument(name = "-d", description = "dataset path", required = true) String datasetPath = null; @Argument(name = "-i", description = "input model path", required = true) String inputModelPath = null; @Argument(name = "-o", description = "output directory path", required = true) String dirPath = null; @Argument(name = "-t", description = "output terminal (default: png)") String terminal = "png"; } /** * Generates scripts for visualizing GAMs. * *
	 * Usage: mltk.predictor.gam.tool.Visualizer
	 * -r	attribute file path
	 * -d	dataset path
	 * -i	input model path
	 * -o	output directory path
	 * [-t]	output terminal (default: png)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(Visualizer.class, opts); try { parser.parse(args); } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Instances dataset = InstancesReader.read(opts.attPath, opts.datasetPath); GAM gam = PredictorReader.read(opts.inputModelPath, GAM.class); Visualizer.generateGnuplotScripts(gam, dataset, opts.dirPath, Terminal.getEnum(opts.terminal)); } } ================================================ FILE: src/main/java/mltk/predictor/gam/tool/package-info.java ================================================ /** * Provides tools for diagnosing and visualizing GAMs. */ package mltk.predictor.gam.tool; ================================================ FILE: src/main/java/mltk/predictor/glm/ElasticNetLearner.java ================================================ package mltk.predictor.glm; import java.util.Arrays; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.cmdline.options.LearnerWithTaskOptions; import mltk.core.Attribute; import mltk.core.DenseVector; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.core.SparseVector; import mltk.core.io.InstancesReader; import mltk.predictor.Family; import mltk.predictor.LinkFunction; import mltk.predictor.io.PredictorWriter; import mltk.util.MathUtils; import mltk.util.OptimUtils; import mltk.util.StatUtils; import mltk.util.VectorUtils; /** * Class for learning elastic-net penalized linear model via coordinate descent. 
 *
 * @author Yin Lou
 *
 */
public class ElasticNetLearner extends GLMLearner {

	/**
	 * Command line options specific to this learner; the remaining options are inherited from
	 * {@code LearnerWithTaskOptions}.
	 */
	static class Options extends LearnerWithTaskOptions {

		// Passed to setMaxNumIters by main. NOTE(review): build() only substitutes its fallback of 20 when the
		// stored count is negative, so the default of 0 appears to request zero coordinate-descent passes -- confirm.
		@Argument(name = "-m", description = "maximum number of iterations (default: 0)")
		int maxIter = 0;

		// Overall regularization strength; 0 means no regularization
		@Argument(name = "-l", description = "lambda (default: 0)")
		double lambda = 0;

		// Share of lambda given to the L1 penalty: 0 is ridge, 1 is lasso, values in between are elastic net
		@Argument(name = "-a", description = "L1 ratio (default: 0)")
		double l1Ratio = 0;

	}

	/**
	 * Trains elastic-net-regularized GLMs.
	 *
	 *
	 * Usage: mltk.predictor.glm.ElasticNetLearner
	 * -t	train set path
	 * [-g]	task between classification (c) and regression (r) (default: r)
	 * [-r]	attribute file path
	 * [-o]	output model path
	 * [-V]	verbose (default: true)
	 * [-m]	maximum number of iterations (default: 0)
	 * [-l]	lambda (default: 0)
	 * [-a]	L1 ratio (default: 0)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(ElasticNetLearner.class, opts); Task task = null; try { parser.parse(args); task = Task.get(opts.task); } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath); ElasticNetLearner learner = new ElasticNetLearner(); learner.setVerbose(opts.verbose); learner.setTask(task); learner.setLambda(opts.lambda); learner.setL1Ratio(opts.l1Ratio); learner.setMaxNumIters(opts.maxIter); long start = System.currentTimeMillis(); GLM glm = learner.build(trainSet); long end = System.currentTimeMillis(); System.out.println("Time: " + (end - start) / 1000.0); if (opts.outputModelPath != null) { PredictorWriter.write(glm, opts.outputModelPath); } } protected double lambda; protected double l1Ratio; protected Task task; /** * Constructor. */ public ElasticNetLearner() { lambda = 0; // no regularization l1Ratio = 0; // 0: ridge, 1: lasso, (0, 1): elastic net task = Task.REGRESSION; } @Override public GLM build(Instances instances) { GLM glm = null; if (maxNumIters < 0) { maxNumIters = 20; } switch (task) { case REGRESSION: glm = buildGaussianRegressor(instances, maxNumIters, lambda, l1Ratio); break; case CLASSIFICATION: glm = buildClassifier(instances, maxNumIters, lambda, l1Ratio); break; default: break; } return glm; } @Override public GLM build(Instances trainSet, Family family) { GLM glm = null; if (maxNumIters < 0) { maxNumIters = 20; } switch (family) { case GAUSSIAN: glm = buildGaussianRegressor(trainSet, maxNumIters, lambda, l1Ratio); break; case BINOMIAL: glm = buildClassifier(trainSet, maxNumIters, lambda, l1Ratio); break; default: throw new IllegalArgumentException("Unsupported family: " + family); } return glm; } /** * Builds an elastic-net penalized binary classifier. 
Each row in the input matrix x represents a feature (instead * of a data point). Thus the input matrix is the transpose of the row-oriented data matrix. This procedure does not * assume the data is normalized or centered. * * @param attrs the attribute list. * @param x the inputs. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @param l1Ratio the L1 ratio. * @return an elastic-net penalized binary classifier. */ public GLM buildBinaryClassifier(int[] attrs, double[][] x, double[] y, int maxNumIters, double lambda, double l1Ratio) { double[] w = new double[attrs.length]; double intercept = 0; double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); // Calculate theta's double[] theta = new double[x.length]; for (int i = 0; i < x.length; i++) { theta[i] = StatUtils.sumSq(x[i]) / 4; } final double lambda1 = lambda * l1Ratio; final double lambda2 = lambda * (1 - l1Ratio); final double tl1 = lambda1 * y.length; final double tl2 = lambda2 * y.length; // Coordinate descent for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeElasticNetLoss(pTrain, y, w, lambda1, lambda2); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(x, theta, y, tl1, tl2, w, pTrain, rTrain); double currLoss = GLMOptimUtils.computeElasticNetLoss(pTrain, y, w, lambda1, lambda2); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.LOGIT); } /** * Builds an elastic-net penalized binary classifier on sparse inputs. Each row of the input represents a feature * (instead of a data point), i.e., in column-oriented format. This procedure does not assume the data is normalized * or centered. * * @param attrs the attribute list. 
* @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @param l1Ratio the L1 ratio. * @return an elastic-net penalized classifier. */ public GLM buildBinaryClassifier(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters, double lambda, double l1Ratio) { double[] w = new double[attrs.length]; double intercept = 0; double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); // Calculate theta's double[] theta = new double[values.length]; for (int i = 0; i < values.length; i++) { theta[i] = StatUtils.sumSq(values[i]) / 4; } final double lambda1 = lambda * l1Ratio; final double lambda2 = lambda * (1 - l1Ratio); final double tl1 = lambda1 * y.length; final double tl2 = lambda2 * y.length; // Coordinate descent for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeElasticNetLoss(pTrain, y, w, lambda1, lambda2); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(indices, values, theta, y, tl1, tl2, w, pTrain, rTrain); double currLoss = GLMOptimUtils.computeElasticNetLoss(pTrain, y, w, lambda1, lambda2); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.LOGIT); } /** * Builds elastic-net penalized binary classifiers for a sequence of regularization parameter lambdas. Each row in * the input matrix x represents a feature (instead of a data point). Thus the input matrix is the transpose of the * row-oriented data matrix. This procedure does not assume the data is normalized or centered. * * @param attrs the attribute list. * @param x the inputs. * @param y the targets. * @param maxNumIters the maximum number of iterations. 
* @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @param l1Ratio the L1 ratio. * @return elastic-net penalized classifiers. */ public GLM[] buildBinaryClassifiers(int[] attrs, double[][] x, double[] y, int maxNumIters, int numLambdas, double minLambdaRatio, double l1Ratio) { double[] w = new double[attrs.length]; double intercept = 0; double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); // Calculate theta's double[] theta = new double[x.length]; for (int i = 0; i < x.length; i++) { theta[i] = StatUtils.sumSq(x[i]) / 4; } double maxLambda = findMaxLambdaBinomial(x, y, pTrain, rTrain, l1Ratio); // Dampening factor for lambda double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas); GLM[] glms = new GLM[numLambdas]; double lambda = maxLambda; for (int g = 0; g < numLambdas; g++) { final double lambda1 = lambda * l1Ratio; final double lambda2 = lambda * (1 - l1Ratio); final double tl1 = lambda1 * y.length; final double tl2 = lambda2 * y.length; // Coordinate descent for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeElasticNetLoss(pTrain, y, w, lambda1, lambda2); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(x, theta, y, tl1, tl2, w, pTrain, rTrain); double currLoss = GLMOptimUtils.computeElasticNetLoss(pTrain, y, w, lambda1, lambda2); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } glms[g] = GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.LOGIT); lambda *= alpha; } return glms; } /** * Builds elastic-net penalized binary classifiers on sparse inputs for a sequence of regularization parameter * lambdas. Each row of the input represents a feature (instead of a data point), i.e., in column-oriented format. 
* This procedure does not assume the data is normalized or centered. * * @param attrs the attribute list. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @param l1Ratio the L1 ratio. * @return an elastic-net penalized classifier. */ public GLM[] buildBinaryClassifiers(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters, int numLambdas, double minLambdaRatio, double l1Ratio) { double[] w = new double[attrs.length]; double intercept = 0; double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); // Calculate theta's double[] theta = new double[values.length]; for (int i = 0; i < values.length; i++) { theta[i] = StatUtils.sumSq(values[i]) / 4; } double maxLambda = findMaxLambdaBinomial(indices, values, y, pTrain, rTrain, l1Ratio); // Dampening factor for lambda double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas); GLM[] glms = new GLM[numLambdas]; double lambda = maxLambda; for (int g = 0; g < glms.length; g++) { final double lambda1 = lambda * l1Ratio; final double lambda2 = lambda * (1 - l1Ratio); final double tl1 = lambda1 * y.length; final double tl2 = lambda2 * y.length; // Coordinate descent for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeElasticNetLoss(pTrain, y, w, lambda1, lambda2); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(indices, values, theta, y, tl1, tl2, w, pTrain, rTrain); double currLoss = GLMOptimUtils.computeElasticNetLoss(pTrain, y, w, lambda1, lambda2); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } glms[g] = GLMOptimUtils.getGLM(attrs, w, 
intercept, LinkFunction.LOGIT); lambda *= alpha; } return glms; } /** * Builds an elastic-net penalized classifier. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @param l1Ratio the l1 ratio. * @return an elastic-net penalized classifer. */ public GLM buildClassifier(Instances trainSet, boolean isSparse, int maxNumIters, double lambda, double l1Ratio) { Attribute classAttribute = trainSet.getTargetAttribute(); if (classAttribute.getType() != Attribute.Type.NOMINAL) { throw new IllegalArgumentException("Class attribute must be nominal."); } NominalAttribute clazz = (NominalAttribute) classAttribute; int numClasses = clazz.getCardinality(); if (isSparse) { SparseDataset sd = getSparseDataset(trainSet, true); int[] attrs = sd.attrs; int[][] indices = sd.indices; double[][] values = sd.values; double[] y = new double[sd.y.length]; double[] cList = sd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == 0 ? 1 : 0; } GLM glm = buildBinaryClassifier(attrs, indices, values, y, maxNumIters, lambda, l1Ratio); double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; w[attIndex] *= cList[j]; } return glm; } else { int p = attrs.length == 0 ? 0 : attrs[attrs.length - 1] + 1; GLM glm = new GLM(numClasses, p); for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == k ? 
1 : 0; } GLM binaryClassifier = buildBinaryClassifier(attrs, indices, values, y, maxNumIters, lambda, l1Ratio); double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } return glm; } } else { DenseDataset dd = getDenseDataset(trainSet, true); int[] attrs = dd.attrs; double[][] x = dd.x; double[] y = new double[dd.y.length]; double[] cList = dd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == 0 ? 1 : 0; } GLM glm = buildBinaryClassifier(attrs, x, y, maxNumIters, lambda, l1Ratio); double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; w[attIndex] *= cList[j]; } return glm; } else { int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1); GLM glm = new GLM(numClasses, p); for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == k ? 1 : 0; } GLM binaryClassifier = buildBinaryClassifier(attrs, x, y, maxNumIters, lambda, l1Ratio); double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } return glm; } } } /** * Builds an elastic-net penalized classifier. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @param l1Ratio the l1 ratio. * @return an elastic-net penalized classifer. */ public GLM buildClassifier(Instances trainSet, int maxNumIters, double lambda, double l1Ratio) { return buildClassifier(trainSet, isSparse(trainSet), maxNumIters, lambda, l1Ratio); } /** * Builds elastic-net penalized classifiers. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. 
* @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @param l1Ratio the l1 ratio. * @return elastic-net penalized classifers. */ public GLM[] buildClassifiers(Instances trainSet, boolean isSparse, int maxNumIters, int numLambdas, double minLambdaRatio, double l1Ratio) { Attribute classAttribute = trainSet.getTargetAttribute(); if (classAttribute.getType() != Attribute.Type.NOMINAL) { throw new IllegalArgumentException("Class attribute must be nominal."); } NominalAttribute clazz = (NominalAttribute) classAttribute; int numClasses = clazz.getCardinality(); if (isSparse) { SparseDataset sd = getSparseDataset(trainSet, true); int[] attrs = sd.attrs; int[][] indices = sd.indices; double[][] values = sd.values; double[] y = new double[sd.y.length]; double[] cList = sd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == 0 ? 1 : 0; } GLM[] glms = buildBinaryClassifiers(attrs, indices, values, y, maxNumIters, numLambdas, minLambdaRatio, l1Ratio); for (GLM glm : glms) { double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; w[attIndex] *= cList[j]; } } return glms; } else { int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1); GLM[] glms = new GLM[numLambdas]; for (int i = 0; i < glms.length; i++) { glms[i] = new GLM(numClasses, p); } for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == k ? 
1 : 0; } GLM[] binaryClassifiers = buildBinaryClassifiers(attrs, indices, values, y, maxNumIters, numLambdas, minLambdaRatio, l1Ratio); for (int l = 0; l < glms.length; l++) { GLM binaryClassifier = binaryClassifiers[l]; GLM glm = glms[l]; double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } } return glms; } } else { DenseDataset dd = getDenseDataset(trainSet, true); int[] attrs = dd.attrs; double[][] x = dd.x; double[] y = new double[dd.y.length]; double[] cList = dd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == 0 ? 1 : 0; } GLM[] glms = buildBinaryClassifiers(attrs, x, y, maxNumIters, numLambdas, minLambdaRatio, l1Ratio); for (GLM glm : glms) { double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; w[attIndex] *= cList[j]; } } return glms; } else { int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1); GLM[] glms = new GLM[numLambdas]; for (int i = 0; i < glms.length; i++) { glms[i] = new GLM(numClasses, p); } for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == k ? 1 : 0; } GLM[] binaryClassifiers = buildBinaryClassifiers(attrs, x, y, maxNumIters, numLambdas, minLambdaRatio, l1Ratio); for (int l = 0; l < glms.length; l++) { GLM binaryClassifier = binaryClassifiers[l]; GLM glm = glms[l]; double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } } return glms; } } } /** * Builds elastic-net penalized classifiers. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. 
* @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @param l1Ratio the l1 ratio. * @return elastic-net penalized classifers. */ public GLM[] buildClassifiers(Instances trainSet, int maxNumIters, int numLambdas, double minLambdaRatio, double l1Ratio) { return buildClassifiers(trainSet, isSparse(trainSet), maxNumIters, numLambdas, minLambdaRatio, l1Ratio); } /** * Builds an elastic-net penalized regressor. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @param l1Ratio the L1 ratio. * @return an elastic-net penalized regressor. */ public GLM buildGaussianRegressor(Instances trainSet, boolean isSparse, int maxNumIters, double lambda, double l1Ratio) { if (isSparse) { SparseDataset sd = getSparseDataset(trainSet, true); double[] cList = sd.cList; GLM glm = buildGaussianRegressor(sd.attrs, sd.indices, sd.values, sd.y, maxNumIters, lambda, l1Ratio); double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = sd.attrs[j]; w[attIndex] *= cList[j]; } return glm; } else { DenseDataset dd = getDenseDataset(trainSet, true); double[] cList = dd.cList; GLM glm = buildGaussianRegressor(dd.attrs, dd.x, dd.y, maxNumIters, lambda, l1Ratio); double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = dd.attrs[j]; w[attIndex] *= cList[j]; } return glm; } } /** * Builds an elastic-net penalized regressor. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @param l1Ratio the L1 ratio. * @return an elastic-net penalized regressor. */ public GLM buildGaussianRegressor(Instances trainSet, int maxNumIters, double lambda, double l1Ratio) { return buildGaussianRegressor(trainSet, isSparse(trainSet), maxNumIters, lambda, l1Ratio); } /** * Builds an elastic-net penalized regressor. 
Each row in the input matrix x represents a feature (instead of a data * point). Thus the input matrix is the transpose of the row-oriented data matrix. This procedure does not assume * the data is normalized or centered. * * @param attrs the attribute list. * @param x the inputs. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @param l1Ratio the L1 ratio. * @return an elastic-net penalized regressor. */ public GLM buildGaussianRegressor(int[] attrs, double[][] x, double[] y, int maxNumIters, double lambda, double l1Ratio) { double[] w = new double[attrs.length]; double intercept = 0; // Initialize residuals double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Calculate sum of squares double[] sq = new double[x.length]; for (int i = 0; i < x.length; i++) { sq[i] = StatUtils.sumSq(x[i]); } final double lambda1 = lambda * l1Ratio; final double lambda2 = lambda * (1 - l1Ratio); final double tl1 = lambda1 * y.length; final double tl2 = lambda2 * y.length; // Coordinate descent for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeElasticNetLoss(rTrain, w, lambda1, lambda2); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(x, sq, tl1, tl2, w, rTrain); double currLoss = GLMOptimUtils.computeElasticNetLoss(rTrain, w, lambda1, lambda2); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.IDENTITY); } /** * Builds an elastic-net penalized regressor on sparse inputs. Each row of the input represents a feature (instead * of a data point), i.e., in column-oriented format. This procedure does not assume the data is normalized or * centered. * * @param attrs the attribute list. * @param indices the indices. * @param values the values. 
* @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @param l1Ratio the L1 ratio. * @return an elastic-net penalized regressor. */ public GLM buildGaussianRegressor(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters, double lambda, double l1Ratio) { double[] w = new double[attrs.length]; double intercept = 0; // Initialize residuals double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Calculate sum of squares double[] sq = new double[attrs.length]; for (int i = 0; i < values.length; i++) { sq[i] = StatUtils.sumSq(values[i]); } final double lambda1 = lambda * l1Ratio; final double lambda2 = lambda * (1 - l1Ratio); final double tl1 = lambda1 * y.length; final double tl2 = lambda2 * y.length; // Coordinate descent for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeElasticNetLoss(rTrain, w, lambda1, lambda2); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(indices, values, sq, tl1, tl2, w, rTrain); double currLoss = GLMOptimUtils.computeElasticNetLoss(rTrain, w, lambda1, lambda2); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.IDENTITY); } /** * Builds elastic-net penalized regressors for a sequence of regularization parameter lambdas. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @param l1Ratio the L1 ratio. * @return elastic-net penalized regressors. 
*/ public GLM[] buildGaussianRegressors(Instances trainSet, boolean isSparse, int maxNumIters, int numLambdas, double minLambdaRatio, double l1Ratio) { if (isSparse) { SparseDataset sd = getSparseDataset(trainSet, true); double[] cList = sd.cList; GLM[] glms = buildGaussianRegressors(sd.attrs, sd.indices, sd.values, sd.y, maxNumIters, numLambdas, minLambdaRatio, l1Ratio); for (GLM glm : glms) { double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = sd.attrs[j]; w[attIndex] *= cList[j]; } } return glms; } else { DenseDataset dd = getDenseDataset(trainSet, true); double[] cList = dd.cList; GLM[] glms = buildGaussianRegressors(dd.attrs, dd.x, dd.y, maxNumIters, numLambdas, minLambdaRatio, l1Ratio); for (GLM glm : glms) { double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = dd.attrs[j]; w[attIndex] *= cList[j]; } } return glms; } } /** * Builds elastic-net penalized regressors for a sequence of regularization parameter lambdas. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @param l1Ratio the L1 ratio. * @return elastic-net penalized regressors. */ public GLM[] buildGaussianRegressors(Instances trainSet, int maxNumIters, int numLambdas, double minLambdaRatio, double l1Ratio) { return buildGaussianRegressors(trainSet, isSparse(trainSet), maxNumIters, numLambdas, minLambdaRatio, l1Ratio); } /** * Builds elastic-net penalized regressors for a sequence of regularization parameter lambdas. Each row in the input * matrix x represents a feature (instead of a data point). Thus the input matrix is the transpose of the * row-oriented data matrix. This procedure does not assume the data is normalized or centered. * * @param attrs the attribute list. * @param x the inputs. * @param y the targets. * @param maxNumIters the maximum number of iterations. 
* @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @param l1Ratio the L1 ratio. * @return elastic-net penalized regressors. */ public GLM[] buildGaussianRegressors(int[] attrs, double[][] x, double[] y, int maxNumIters, int numLambdas, double minLambdaRatio, double l1Ratio) { double[] w = new double[attrs.length]; double intercept = 0; GLM[] glms = new GLM[numLambdas]; // Backup targets double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Calculate sum of squares double[] sq = new double[x.length]; for (int i = 0; i < x.length; i++) { sq[i] = StatUtils.sumSq(x[i]); } // Determine max lambda double maxLambda = findMaxLambdaGaussian(x, y, l1Ratio); // Dampening factor for lambda double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas); // Compute the regularization path double lambda = maxLambda; for (int g = 0; g < glms.length; g++) { final double lambda1 = lambda * l1Ratio; final double lambda2 = lambda * (1 - l1Ratio); final double tl1 = lambda1 * y.length; final double tl2 = lambda2 * y.length; // Coordinate descent for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeElasticNetLoss(rTrain, w, lambda1, lambda2); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(x, sq, tl1, tl2, w, rTrain); double currLoss = GLMOptimUtils.computeElasticNetLoss(rTrain, w, lambda1, lambda2); if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } lambda *= alpha; glms[g] = GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.IDENTITY); } return glms; } /** * Builds elastic-net penalized regressors on sparse inputs for a sequence of regularization parameter lambdas. Each row of the input * represents a feature (instead of a data point), i.e., in column-oriented format. 
This procedure does not assume * the data is normalized or centered. * * @param attrs the attribute list. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @param l1Ratio the L1 ratio. * @return elastic-net penalized regressors. */ public GLM[] buildGaussianRegressors(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters, int numLambdas, double minLambdaRatio, double l1Ratio) { double[] w = new double[attrs.length]; double intercept = 0; GLM[] glms = new GLM[numLambdas]; // Backup targets double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Calculate sum of squares double[] sq = new double[values.length]; for (int i = 0; i < values.length; i++) { sq[i] = StatUtils.sumSq(values[i]); } // Determine max lambda double maxLambda = findMaxLambdaGaussian(indices, values, y, l1Ratio); // Dampening factor for lambda double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas); // Compute the regularization path double lambda = maxLambda; for (int g = 0; g < glms.length; g++) { final double lambda1 = lambda * l1Ratio; final double lambda2 = lambda * (1 - l1Ratio); final double tl1 = lambda1 * y.length; final double tl2 = lambda2 * y.length; // Coordinate descent for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeElasticNetLoss(rTrain, w, lambda1, lambda2); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(indices, values, sq, tl1, tl2, w, rTrain); double currLoss = GLMOptimUtils.computeElasticNetLoss(rTrain, w, lambda1, lambda2); if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } lambda *= alpha; glms[g] = GLMOptimUtils.getGLM(attrs, w, 
intercept, LinkFunction.IDENTITY); } return glms; } protected void doOnePassGaussian(double[][] x, double[] sq, final double tl1, final double tl2, double[] w, double[] rTrain) { for (int j = 0; j < x.length; j++) { double[] v = x[j]; // Calculate weight updates using naive updates double wNew = w[j] * sq[j] + VectorUtils.dotProduct(v, rTrain); if (Math.abs(wNew) <= tl1) { wNew = 0; } else if (wNew > 0) { wNew -= tl1; } else { wNew += tl1; } wNew /= (sq[j] + tl2); double delta = wNew - w[j]; w[j] = wNew; // Update residuals for (int i = 0; i < rTrain.length; i++) { rTrain[i] -= delta * v[i]; } } } protected void doOnePassGaussian(int[][] indices, double[][] values, double[] sq, final double tl1, final double tl2, double[] w, double[] rTrain) { for (int j = 0; j < indices.length; j++) { // Calculate weight updates using naive updates double wNew = w[j] * sq[j]; int[] index = indices[j]; double[] value = values[j]; for (int i = 0; i < index.length; i++) { wNew += rTrain[index[i]] * value[i]; } if (Math.abs(wNew) <= tl1) { wNew = 0; } else if (wNew > 0) { wNew -= tl1; } else { wNew += tl1; } wNew /= (sq[j] + tl2); double delta = wNew - w[j]; w[j] = wNew; // Update residuals for (int i = 0; i < index.length; i++) { rTrain[index[i]] -= delta * value[i]; } } } protected void doOnePassBinomial(double[][] x, double[] theta, double[] y, final double tl1, final double tl2, double[] w, double[] pTrain, double[] rTrain) { for (int j = 0; j < x.length; j++) { if (Math.abs(theta[j]) <= MathUtils.EPSILON) { continue; } double[] v = x[j]; double eta = VectorUtils.dotProduct(rTrain, v); double newW = w[j] * theta[j] + eta; if (newW > tl1) { newW -= tl1; } else if (newW < -tl1) { newW += tl1; } else { newW = 0; } newW /= (theta[j] + tl2); double delta = newW - w[j]; w[j] = newW; // Update predictions for (int i = 0; i < pTrain.length; i++) { pTrain[i] += delta * v[i]; rTrain[i] = OptimUtils.getPseudoResidual(pTrain[i], y[i]); } } } protected void doOnePassBinomial(int[][] indices, 
double[][] values, double[] theta, double[] y, final double tl1, final double tl2, double[] w, double[] pTrain, double[] rTrain) { for (int j = 0; j < indices.length; j++) { if (Math.abs(theta[j]) <= MathUtils.EPSILON) { continue; } double eta = 0; int[] index = indices[j]; double[] value = values[j]; for (int i = 0; i < index.length; i++) { int idx = index[i]; eta += rTrain[idx] * value[i]; } double newW = w[j] * theta[j] + eta; if (newW > tl1) { newW -= tl1; } else if (newW < -tl1) { newW += tl1; } else { newW = 0; } newW /= (theta[j] + tl2); double delta = newW - w[j]; w[j] = newW; // Update predictions for (int i = 0; i < index.length; i++) { int idx = index[i]; pTrain[idx] += delta * value[i]; rTrain[idx] = OptimUtils.getPseudoResidual(pTrain[idx], y[idx]); } } } protected double findMaxLambdaGaussian(double[][] x, double[] y, double l1Ratio) { double mean = 0; if (fitIntercept) { mean = OptimUtils.fitIntercept(y); } // Determine max lambda double maxLambda = 0; for (double[] col : x) { double dot = Math.abs(VectorUtils.dotProduct(col, y)); if (dot > maxLambda) { maxLambda = dot; } } maxLambda /= y.length; maxLambda /= l1Ratio; if (fitIntercept) { VectorUtils.add(y, mean); } return maxLambda; } protected double findMaxLambdaGaussian(int[][] indices, double[][] values, double[] y, double l1Ratio) { double mean = 0; if (fitIntercept) { mean = OptimUtils.fitIntercept(y); } DenseVector v = new DenseVector(y); // Determine max lambda double maxLambda = 0; for (int i = 0; i < indices.length; i++) { int[] index = indices[i]; double[] value = values[i]; double dot = Math.abs(VectorUtils.dotProduct(new SparseVector(index, value), v)); if (dot > maxLambda) { maxLambda = dot; } } maxLambda /= y.length; maxLambda /= l1Ratio; if (fitIntercept) { VectorUtils.add(y, mean); } return maxLambda; } protected double findMaxLambdaBinomial(double[][] x, double[] y, double[] pTrain, double[] rTrain, double l1Ratio) { if (fitIntercept) { OptimUtils.fitIntercept(pTrain, rTrain, y); } 
double maxLambda = 0; for (double[] col : x) { double eta = 0; for (int i = 0; i < col.length; i++) { eta += rTrain[i] * col[i]; } double t = Math.abs(eta); if (t > maxLambda) { maxLambda = t; } } maxLambda /= y.length; maxLambda /= l1Ratio; if (fitIntercept) { Arrays.fill(pTrain, 0); OptimUtils.computePseudoResidual(pTrain, y, rTrain); } return maxLambda; } protected double findMaxLambdaBinomial(int[][] indices, double[][] values, double[] y, double[] pTrain, double[] rTrain, double l1Ratio) { if (fitIntercept) { OptimUtils.fitIntercept(pTrain, rTrain, y); } double maxLambda = 0; for (int k = 0; k < values.length; k++) { double eta = 0; int[] index = indices[k]; double[] value = values[k]; for (int i = 0; i < index.length; i++) { int idx = index[i]; double r = OptimUtils.getPseudoResidual(pTrain[idx], y[idx]); r *= value[i]; eta += r; } double t = Math.abs(eta); if (t > maxLambda) { maxLambda = t; } } maxLambda /= y.length; maxLambda /= l1Ratio; if (fitIntercept) { Arrays.fill(pTrain, 0); OptimUtils.computePseudoResidual(pTrain, y, rTrain); } return maxLambda; } /** * Returns the l1 ratio. * * @return the l1 ratio. */ public double getL1Ratio() { return l1Ratio; } /** * Returns the lambda. * * @return the lambda. */ public double getLambda() { return lambda; } /** * Returns the task of this learner. * * @return the task of this learner. */ public Task getTask() { return task; } /** * Sets the l1 ratio. * * @param l1Ratio the l1 ratio. */ public void setL1Ratio(double l1Ratio) { this.l1Ratio = l1Ratio; } /** * Sets the lambda. * * @param lambda the lambda. */ public void setLambda(double lambda) { this.lambda = lambda; } /** * Sets the task of this learner. * * @param task the task of this learner. 
*/ public void setTask(Task task) { this.task = task; } } ================================================ FILE: src/main/java/mltk/predictor/glm/GLM.java ================================================ package mltk.predictor.glm; import java.io.BufferedReader; import java.io.PrintWriter; import java.util.Arrays; import mltk.core.Instance; import mltk.core.SparseVector; import mltk.predictor.LinkFunction; import mltk.predictor.ProbabilisticClassifier; import mltk.predictor.Regressor; import mltk.util.ArrayUtils; import mltk.util.MathUtils; import mltk.util.StatUtils; /** * Class for generalized linear models (GLMs). * * @author Yin Lou * */ public class GLM implements ProbabilisticClassifier, Regressor { /** * The coefficient vectors. */ protected double[][] w; /** * The intercept vector. */ protected double[] intercept; /** * The link function. */ protected LinkFunction link; /** * Constructor. */ public GLM() { } /** * Constructs a GLM with the specified dimension. * * @param dimension the dimension. */ public GLM(int dimension) { this(1, dimension); } /** * Constructs a GLM with the specified dimension. * * @param numClasses the number of classes. * @param dimension the dimension. */ public GLM(int numClasses, int dimension) { w = new double[numClasses][dimension]; intercept = new double[numClasses]; } /** * Constructs a GLM with the intercept vector and the coefficient vectors. * * @param intercept the intercept vector. * @param w the coefficient vectors. */ public GLM(double[] intercept, double[][] w) { this(intercept, w, LinkFunction.IDENTITY); } /** * Constructs a GLM with the intercept vector, the coefficient vectors and its link function. * * @param intercept the intercept vector. * @param w the coefficient vectors. * @param link the link function. 
*/ public GLM(double[] intercept, double[][] w, LinkFunction link) { if (intercept.length != w.length) { throw new IllegalArgumentException("Dimensions of intercept and w must match."); } this.intercept = intercept; this.w = w; this.link = link; } /** * Returns the coefficient vectors. * * @return the coefficient vectors. */ public double[][] coefficients() { return w; } /** * Returns the coefficient vectors for class k. * * @param k the index of the class. * @return the coefficient vectors for class k. */ public double[] coefficients(int k) { return w[k]; } /** * Returns the intercept vector. * * @return the intercept vector. */ public double[] intercept() { return intercept; } /** * Returns the intercept for class k. * * @param k the index of the class. * @return the intercept for class k. */ public double intercept(int k) { return intercept[k]; } @Override public void read(BufferedReader in) throws Exception { link = LinkFunction.get(in.readLine().split(": ")[1]); in.readLine(); intercept = ArrayUtils.parseDoubleArray(in.readLine()); int p = Integer.parseInt(in.readLine().split(": ")[1]); w = new double[intercept.length][p]; for (int j = 0; j < p; j++) { String[] data = in.readLine().split("\\s+"); for (int i = 0; i < w.length; i++) { w[i][j] = Double.parseDouble(data[i]); } } } @Override public void write(PrintWriter out) throws Exception { out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName()); out.println("Link: " + link); out.println("Intercept: " + intercept.length); out.println(Arrays.toString(intercept)); final int p = w[0].length; out.println("Coefficients: " + p); for (int j = 0; j < p; j++) { out.print(w[0][j]); for (int i = 1; i < w.length; i++) { out.print(" " + w[i][j]); } out.println(); } } @Override public double regress(Instance instance) { return regress(intercept[0], w[0], instance); } @Override public int classify(Instance instance) { double[] prob = predictProbabilities(instance); return StatUtils.indexOfMax(prob); } /** * 
Returns the prediction of this GLM on the scale of the response variable. * * @param instance the instance to predict. * @return the prediction of this GLM on the scale of the response variable. */ public double predict(Instance instance) { return link.applyInverse(regress(instance)); } @Override public double[] predictProbabilities(Instance instance) { if (w.length == 1) { double[] prob = new double[2]; double pred = regress(intercept[0], w[0], instance); prob[0] = MathUtils.sigmoid(pred); prob[1] = 1 - prob[0]; return prob; } else { double[] prob = new double[w.length]; double[] pred = new double[w.length]; double sum = 0; for (int i = 0; i < w.length; i++) { pred[i] = regress(intercept[i], w[i], instance); prob[i] = MathUtils.sigmoid(pred[i]); sum += prob[i]; } for (int i = 0; i < prob.length; i++) { prob[i] /= sum; } return prob; } } @Override public GLM copy() { double[][] copyW = new double[w.length][]; for (int i = 0; i < copyW.length; i++) { copyW[i] = Arrays.copyOf(w[i], w[i].length); } return new GLM(intercept, copyW, link); } protected double regress(double intercept, double[] coef, Instance instance) { if (!instance.isSparse()) { double pred = intercept; for (int i = 0; i < coef.length; i++) { pred += coef[i] * instance.getValue(i); } return pred; } else { double pred = intercept; SparseVector vector = (SparseVector) instance.getVector(); int[] indices = vector.getIndices(); double[] values = vector.getValues(); for (int i = 0; i < indices.length; i++) { int index = indices[i]; if (index < coef.length) { pred += coef[index] * values[i]; } } return pred; } } } ================================================ FILE: src/main/java/mltk/predictor/glm/GLMLearner.java ================================================ package mltk.predictor.glm; import mltk.core.Instances; import mltk.predictor.Family; import mltk.predictor.Learner; import mltk.util.MathUtils; /** * Abstract class for learning generalized linear models (GLMs). 
*
 * @author Yin Lou
 *
 */
public abstract class GLMLearner extends Learner {

    /** Whether an intercept term is fitted; {@code true} by default. */
    protected boolean fitIntercept = true;

    /** The maximum number of iterations; {@code -1} by default. */
    protected int maxNumIters = -1;

    /** The convergence threshold; {@code MathUtils.EPSILON} by default. */
    protected double epsilon = MathUtils.EPSILON;

    /** The response distribution family; {@code Family.GAUSSIAN} by default. */
    protected Family family = Family.GAUSSIAN;

    /**
     * Constructor. Turns verbose output off; every other setting keeps its declared default.
     */
    public GLMLearner() {
        verbose = false;
    }

    /**
     * Tells whether an intercept is fitted.
     *
     * @return {@code true} if we fit intercept.
     */
    public boolean fitIntercept() {
        return this.fitIntercept;
    }

    /**
     * Chooses whether an intercept is fitted.
     *
     * @param fitIntercept whether we fit intercept.
     */
    public void fitIntercept(boolean fitIntercept) {
        this.fitIntercept = fitIntercept;
    }

    /**
     * Gets the convergence threshold epsilon.
     *
     * @return the convergence threshold epsilon.
     */
    public double getEpsilon() {
        return this.epsilon;
    }

    /**
     * Changes the convergence threshold epsilon.
     *
     * @param epsilon the convergence threshold epsilon.
     */
    public void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    /**
     * Gets the maximum number of iterations.
     *
     * @return the maximum number of iterations.
     */
    public int getMaxNumIters() {
        return this.maxNumIters;
    }

    /**
     * Changes the maximum number of iterations.
     *
     * @param maxNumIters the maximum number of iterations.
     */
    public void setMaxNumIters(int maxNumIters) {
        this.maxNumIters = maxNumIters;
    }

    /**
     * Gets the response distribution family.
     *
     * @return the response distribution family.
     */
    public Family getFamily() {
        return this.family;
    }

    /**
     * Changes the response distribution family.
     *
     * @param family the response distribution family.
     */
    public void setFamily(Family family) {
        this.family = family;
    }

    /**
     * Builds a generalized linear model given response distribution family.
     * The default link function for the family will be used.
     *
     * @param trainSet the training set.
     * @param family the response distribution family.
     * @return a generalized linear model.
*/ public abstract GLM build(Instances trainSet, Family family); } ================================================ FILE: src/main/java/mltk/predictor/glm/GLMOptimUtils.java ================================================ package mltk.predictor.glm; import mltk.predictor.LinkFunction; import mltk.util.OptimUtils; import mltk.util.StatUtils; import mltk.util.VectorUtils; class GLMOptimUtils { static GLM getGLM(int[] attrs, double[] w, double intercept, LinkFunction link) { final int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1); GLM glm = new GLM(p); for (int i = 0; i < attrs.length; i++) { glm.w[0][attrs[i]] = w[i]; } glm.intercept[0] = intercept; glm.link = link; return glm; } static double computeRidgeLoss(double[] residual, double[] w, double lambda) { double loss = OptimUtils.computeQuadraticLoss(residual); loss += lambda / 2 * StatUtils.sumSq(w); return loss; } static double computeRidgeLoss(double[] pred, double[] y, double[] w, double lambda) { double loss = OptimUtils.computeLogisticLoss(pred, y); loss += lambda / 2 * StatUtils.sumSq(w); return loss; } static double computeLassoLoss(double[] residual, double[] w, double lambda) { double loss = OptimUtils.computeQuadraticLoss(residual); loss += lambda * VectorUtils.l1norm(w); return loss; } static double computeLassoLoss(double[] pred, double[] y, double[] w, double lambda) { double loss = OptimUtils.computeLogisticLoss(pred, y); loss += lambda * VectorUtils.l1norm(w); return loss; } static double computeElasticNetLoss(double[] residual, double[] w, double lambda1, double lambda2) { double loss = OptimUtils.computeQuadraticLoss(residual); loss += lambda1 * VectorUtils.l1norm(w) + lambda2 / 2 * StatUtils.sumSq(w); return loss; } static double computeElasticNetLoss(double[] pred, double[] y, double[] w, double lambda1, double lambda2) { double loss = OptimUtils.computeLogisticLoss(pred, y); loss += lambda1 * VectorUtils.l1norm(w) + lambda2 / 2 * StatUtils.sumSq(w); return loss; } static double 
computeGroupLassoLoss(double[] residual, double[][] w, double[] tl1) { double loss = OptimUtils.computeQuadraticLoss(residual); for (int k = 0; k < w.length; k++) { loss += tl1[k] * StatUtils.sumSq(w[k]); } return loss; } static double computeGroupLassoLoss(double[] pred, double[] y, double[][] w, double[] tl1) { double loss = OptimUtils.computeLogisticLoss(pred, y); for (int k = 0; k < w.length; k++) { loss += tl1[k] * StatUtils.sumSq(w[k]); } return loss; } } ================================================ FILE: src/main/java/mltk/predictor/glm/GroupLassoLearner.java ================================================ package mltk.predictor.glm; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import mltk.core.Attribute; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.predictor.Family; import mltk.predictor.LinkFunction; import mltk.predictor.glm.GLM; import mltk.util.ArrayUtils; import mltk.util.OptimUtils; import mltk.util.StatUtils; import mltk.util.VectorUtils; /** * Class for learning group-lasso models via block coordinate gradient descent. * *

* Reference:
* M Yuan and Y Lin. Model selection and estimation in regression with grouped variables. In * Journal of the Royal Statistical Society: Series B (Statistical Methodology), * 68(1):49-67, 2006. *

*
 * @author Yin Lou
 *
 */
public class GroupLassoLearner extends GLMLearner {

    /**
     * Dense design matrix organized by groups of variables.
     */
    class DenseDesignMatrix {

        int[][] groups;
        double[][][] x;

        DenseDesignMatrix(int[][] groups, double[][][] x) {
            this.groups = groups;
            this.x = x;
        }

    }

    /**
     * Wrapper around a boolean selection vector that compares and hashes by content, so that it can be stored
     * in hash-based collections.
     */
    static class ModelStructure {

        boolean[] structure;

        ModelStructure(boolean[] structure) {
            this.structure = structure;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            if (obj == null)
                return false;
            if (getClass() != obj.getClass())
                return false;
            ModelStructure other = (ModelStructure) obj;
            if (!Arrays.equals(structure, other.structure))
                return false;
            return true;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + Arrays.hashCode(structure);
            return result;
        }

    }

    /**
     * Sparse design matrix organized by groups of variables.
     */
    class SparseDesignMatrix {

        int[][] group;
        int[][][] indices;
        double[][][] values;

        SparseDesignMatrix(int[][] groups, int[][][] indices, double[][][] values) {
            // The groups argument used to be dropped, leaving this field null
            this.group = groups;
            this.indices = indices;
            this.values = values;
        }

    }

    protected boolean refit;
    protected int numLambdas;
    protected double lambda;
    protected Task task;
    // NOTE(review): the element type of this list is not visible here - confirm against the full file
    protected List groups;

    /**
     * Constructor.
*/ public GroupLassoLearner() { refit = false; lambda = 0.0; task = Task.REGRESSION; groups = null; } @Override public GLM build(Instances instances) { if (groups == null) { throw new IllegalArgumentException("Groups are not set."); } GLM glm = null; if (maxNumIters < 0) { maxNumIters = 20; } switch (task) { case REGRESSION: glm = buildGaussianRegressor(instances, groups, maxNumIters, lambda); break; case CLASSIFICATION: glm = buildClassifier(instances, groups, maxNumIters, lambda); break; default: break; } return glm; } @Override public GLM build(Instances trainSet, Family family) { if (groups == null) { throw new IllegalArgumentException("Groups are not set."); } GLM glm = null; if (maxNumIters < 0) { maxNumIters = 20; } switch (family) { case GAUSSIAN: glm = buildGaussianRegressor(trainSet, groups, maxNumIters, lambda); break; case BINOMIAL: glm = buildClassifier(trainSet, groups, maxNumIters, lambda); break; default: throw new IllegalArgumentException("Unsupported family: " + family); } return glm; } /** * Builds a group-lasso penalized binary classifier. The input matrix is grouped by groups. This procedure does not * assume the data is normalized or centered. * * @param attrs the groups of variables. * @param x the inputs. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return a group-lasso penalized classifier. 
*/ public GLM buildBinaryClassifier(int[][] attrs, double[][][] x, double[] y, int maxNumIters, double lambda) { int p = 0; if (attrs.length > 0) { for (int[] attr : attrs) { p = Math.max(p, StatUtils.max(attr)); } p += 1; } double[][] w = new double[attrs.length][]; double[] tl1 = new double[attrs.length]; int m = 0; for (int j = 0; j < attrs.length; j++) { w[j] = new double[attrs[j].length]; tl1[j] = lambda * Math.sqrt(w[j].length); if (w[j].length > m) { m = w[j].length; } } double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); double[] stepSize = new double[attrs.length]; for (int j = 0; j < stepSize.length; j++) { double max = 0; double[][] block = x[j]; for (double[] t : block) { double l = StatUtils.sumSq(t) / 4; if (l > max) { max = l; } } stepSize[j] = 1.0 / max; } double[] g = new double[m]; double[] gradient = new double[m]; boolean[] activeSet = new boolean[attrs.length]; double intercept = 0; // Block coordinate gradient descent int iter = 0; while (iter < maxNumIters) { if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } boolean activeSetChanged = doOnePassBinomial(x, y, tl1, true, activeSet, w, stepSize, g, gradient, pTrain, rTrain); iter++; if (!activeSetChanged || iter > maxNumIters) { break; } for (; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(x, y, tl1, false, activeSet, w, stepSize, g, gradient, pTrain, rTrain); double currLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1); if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } } } if (refit) { boolean[] selected = new boolean[attrs.length]; for (int i = 0; i < selected.length; i++) { selected[i] = !ArrayUtils.isConstant(w[i], 0, 
w[i].length, 0); } return refitGaussianRegressor(p, attrs, selected, x, y, w, maxNumIters); } else { return getGLM(p, attrs, w, intercept, LinkFunction.LOGIT); } } /** * Builds a group-lasso penalized binary classifier on sparse inputs. The input matrix is grouped by groups. This procedure does not * assume the data is normalized or centered. * * @param attrs the groups of variables. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return a group-lasso penalized classifier. */ public GLM buildBinaryClassifier(int[][] attrs, int[][][] indices, double[][][] values, double[] y, int maxNumIters, double lambda) { int p = 0; if (attrs.length > 0) { for (int[] attr : attrs) { p = Math.max(p, StatUtils.max(attr)); } p += 1; } int[][] indexUnion = new int[attrs.length][]; for (int g = 0; g < attrs.length; g++) { int[][] index = indices[g]; Set set = new HashSet<>(); for (int[] idx : index) { for (int i : idx) { set.add(i); } } int[] idxUnion = new int[set.size()]; int k = 0; for (int idx : set) { idxUnion[k++] = idx; } indexUnion[g] = idxUnion; } double[][] w = new double[attrs.length][]; double[] tl1 = new double[attrs.length]; int m = 0; for (int j = 0; j < attrs.length; j++) { w[j] = new double[attrs[j].length]; tl1[j] = lambda * Math.sqrt(w[j].length); if (w[j].length > m) { m = w[j].length; } } double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); double[] stepSize = new double[attrs.length]; for (int j = 0; j < stepSize.length; j++) { double max = 0; double[][] block = values[j]; for (double[] t : block) { double l = StatUtils.sumSq(t) / 4; if (l > max) { max = l; } } stepSize[j] = 1.0 / max; } double[] g = new double[m]; double[] gradient = new double[m]; boolean[] activeSet = new boolean[values.length]; double intercept = 0; // Block coordinate gradient descent int iter = 
0; while (iter < maxNumIters) { if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } boolean activeSetChanged = doOnePassBinomial(indices, indexUnion, values, y, tl1, true, activeSet, w, stepSize, g, gradient, pTrain, rTrain); iter++; if (!activeSetChanged || iter > maxNumIters) { break; } for (; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(indices, indexUnion, values, y, tl1, true, activeSet, w, stepSize, g, gradient, pTrain, rTrain); double currLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1); if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } } } if (refit) { boolean[] selected = new boolean[attrs.length]; for (int i = 0; i < selected.length; i++) { selected[i] = !ArrayUtils.isConstant(w[i], 0, w[i].length, 0); } return refitClassifier(p, attrs, selected, indices, values, y, w, maxNumIters); } else { return getGLM(p, attrs, w, intercept, LinkFunction.LOGIT); } } /** * Builds group-lasso penalized binary classifiers for a sequence of regularization parameter lambdas. The input matrix is grouped by groups. This procedure does not * assume the data is normalized or centered. * * @param attrs the groups of variables. * @param x the inputs. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @return group-lasso penalized classifiers. 
*/ public List buildBinaryClassifiers(int[][] attrs, double[][][] x, double[] y, int maxNumIters, int numLambdas, double minLambdaRatio) { int p = 0; if (attrs.length > 0) { for (int[] attr : attrs) { p = Math.max(p, StatUtils.max(attr)); } p += 1; } double[][] w = new double[attrs.length][]; int m = 0; for (int j = 0; j < attrs.length; j++) { w[j] = new double[attrs[j].length]; if (w[j].length > m) { m = w[j].length; } } double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); double[] g = new double[m]; double[] gradient = new double[m]; double[] tl1 = new double[x.length]; double[] stepSize = new double[x.length]; for (int j = 0; j < stepSize.length; j++) { double max = 0; double[][] block = x[j]; for (double[] t : block) { double l = StatUtils.sumSq(t) / 4; if (l > max) { max = l; } } stepSize[j] = 1.0 / max; } boolean[] activeSet = new boolean[x.length]; double maxLambda = findMaxLambdaBinomial(x, y, pTrain, rTrain, gradient); // Dampening factor for lambda double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas); // Compute the regularization path List glms = new ArrayList<>(numLambdas); double lambda = maxLambda; double intercept = 0; Set structures = new HashSet<>(); for (int l = 0; l < numLambdas; l++) { // Initialize regularization parameters for (int j = 0; j < tl1.length; j++) { tl1[j] = lambda * Math.sqrt(w[j].length); } // Block coordinate gradient descent int iter = 0; while (iter < maxNumIters) { if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } boolean activeSetChanged = doOnePassBinomial(x, y, tl1, true, activeSet, w, stepSize, g, gradient, pTrain, rTrain); iter++; if (!activeSetChanged || iter > maxNumIters) { break; } for (; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(x, y, tl1, false, 
activeSet, w, stepSize, g, gradient, pTrain, rTrain); double currLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1); if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } } } lambda *= alpha; if (refit) { boolean[] selected = new boolean[attrs.length]; for (int i = 0; i < selected.length; i++) { selected[i] = !ArrayUtils.isConstant(w[i], 0, w[i].length, 0); } ModelStructure structure = new ModelStructure(selected); if (!structures.contains(structure)) { GLM glm = refitClassifier(p, attrs, selected, x, y, w, maxNumIters); glms.add(glm); structures.add(structure); } } else { GLM glm = getGLM(p, attrs, w, intercept, LinkFunction.LOGIT); glms.add(glm); } } return glms; } /** * Builds group-lasso penalized binary classifiers on sparse inputs for a sequence of regularization parameter lambdas. The input matrix is grouped by groups. This procedure does not * assume the data is normalized or centered. * * @param attrs the groups of variables. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @return group-lasso penalized classifiers. 
*/ public List buildBinaryClassifiers(int[][] attrs, int[][][] indices, double[][][] values, double[] y, int maxNumIters, int numLambdas, double minLambdaRatio) { int p = 0; if (attrs.length > 0) { for (int[] attr : attrs) { p = Math.max(p, StatUtils.max(attr)); } p += 1; } int[][] indexUnion = new int[indices.length][]; for (int g = 0; g < indices.length; g++) { int[][] index = indices[g]; Set set = new HashSet<>(); for (int[] idx : index) { for (int i : idx) { set.add(i); } } int[] idxUnion = new int[set.size()]; int k = 0; for (int idx : set) { idxUnion[k++] = idx; } indexUnion[g] = idxUnion; } double[][] w = new double[attrs.length][]; double[] tl1 = new double[attrs.length]; int m = 0; for (int j = 0; j < values.length; j++) { w[j] = new double[values[j].length]; if (w[j].length > m) { m = w[j].length; } } double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); double[] stepSize = new double[values.length]; for (int j = 0; j < stepSize.length; j++) { double max = 0; double[][] block = values[j]; for (double[] t : block) { double l = StatUtils.sumSq(t) / 4; if (l > max) { max = l; } } stepSize[j] = 1.0 / max; } double[] g = new double[m]; double[] gradient = new double[m]; boolean[] activeSet = new boolean[values.length]; double maxLambda = findMaxLambdaBinomial(indices, values, y, pTrain, rTrain, gradient); // Dampening factor for lambda double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas); List glms = new ArrayList<>(numLambdas); double lambda = maxLambda; double intercept = 0; Set structures = new HashSet<>(); for (int l = 0; l < numLambdas; l++) { // Initialize regularization parameters for (int j = 0; j < tl1.length; j++) { tl1[j] = lambda * Math.sqrt(w[j].length); } // Block coordinate gradient descent int iter = 0; while (iter < maxNumIters) { if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } boolean activeSetChanged = doOnePassBinomial(indices, 
indexUnion, values, y, tl1, true, activeSet, w, stepSize, g, gradient, pTrain, rTrain); iter++; if (!activeSetChanged || iter > maxNumIters) { break; } for (; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(indices, indexUnion, values, y, tl1, false, activeSet, w, stepSize, g, gradient, pTrain, rTrain); double currLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1); if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } } } lambda *= alpha; if (refit) { boolean[] selected = new boolean[attrs.length]; for (int i = 0; i < selected.length; i++) { selected[i] = !ArrayUtils.isConstant(w[i], 0, w[i].length, 0); } ModelStructure structure = new ModelStructure(selected); if (!structures.contains(structure)) { GLM glm = refitClassifier(p, attrs, selected, indices, values, y, w, maxNumIters); glms.add(glm); } } else { GLM glm = getGLM(p, attrs, w, intercept, LinkFunction.LOGIT); glms.add(glm); } } return glms; } /** * Builds a group-lasso penalized classifier. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param groups the groups of variables. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return a group-lasso penalized classifier. 
*/ public GLM buildClassifier(Instances trainSet, boolean isSparse, List groups, int maxNumIters, double lambda) { Attribute classAttribute = trainSet.getTargetAttribute(); if (classAttribute.getType() != Attribute.Type.NOMINAL) { throw new IllegalArgumentException("Class attribute must be nominal."); } NominalAttribute clazz = (NominalAttribute) classAttribute; int numClasses = clazz.getCardinality(); if (isSparse) { SparseDataset sd = getSparseDataset(trainSet, true); SparseDesignMatrix sm = createDesignMatrix(sd, groups); int[] attrs = sd.attrs; int[][] group = sm.group; int[][][] indices = sm.indices; double[][][] values = sm.values; double[] y = new double[sd.y.length]; double[] cList = sd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == 0 ? 1 : 0; } GLM glm = buildBinaryClassifier(group, indices, values, y, maxNumIters, lambda); double[] w = glm.coefficients(0); for (int j = 0; j < cList.length; j++) { int attIndex = sd.attrs[j]; w[attIndex] *= cList[j]; } return glm; } else { int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1); GLM glm = new GLM(numClasses, p); for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == k ? 1 : 0; } GLM binaryClassifier = buildBinaryClassifier(group, indices, values, y, maxNumIters, lambda); double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } return glm; } } else { DenseDataset dd = getDenseDataset(trainSet, true); DenseDesignMatrix dm = createDesignMatrix(dd, groups); int[] attrs = dd.attrs; int[][] group = dm.groups; double[][][] x = dm.x; double[] y = new double[dd.y.length]; double[] cList = dd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == 0 ? 
1 : 0; } GLM glm = buildBinaryClassifier(group, x, y, maxNumIters, lambda); double[] w = glm.coefficients(0); for (int j = 0; j < cList.length; j++) { int attIndex = dd.attrs[j]; w[attIndex] *= cList[j]; } return glm; } else { int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1); GLM glm = new GLM(numClasses, p); for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == k ? 1 : 0; } GLM binaryClassifier = buildBinaryClassifier(group, x, y, maxNumIters, lambda); double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } return glm; } } } /** * Builds a group-lasso penalized classifier. * * @param trainSet the training set. * @param groups the groups of variables. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return a group-lasso penalized classifier. */ public GLM buildClassifier(Instances trainSet, List groups, int maxNumIters, double lambda) { return buildClassifier(trainSet, isSparse(trainSet), groups, maxNumIters, lambda); } /** * Builds group-lasso penalized classifiers. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param groups the groups of variables. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @return group-lasso penalized classifiers. 
*/ public List buildClassifiers(Instances trainSet, boolean isSparse, List groups, int maxNumIters, int numLambdas, double minLambdaRatio) { Attribute classAttribute = trainSet.getTargetAttribute(); if (classAttribute.getType() != Attribute.Type.NOMINAL) { throw new IllegalArgumentException("Class attribute must be nominal."); } NominalAttribute clazz = (NominalAttribute) classAttribute; int numClasses = clazz.getCardinality(); if (isSparse) { SparseDataset sd = getSparseDataset(trainSet, true); SparseDesignMatrix sm = createDesignMatrix(sd, groups); int[] attrs = sd.attrs; int[][] group = sm.group; int[][][] indices = sm.indices; double[][][] values = sm.values; double[] y = new double[sd.y.length]; double[] cList = sd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == 0 ? 1 : 0; } List glms = buildBinaryClassifiers(sm.group, sm.indices, sm.values, y, maxNumIters, numLambdas, minLambdaRatio); for (GLM glm : glms) { double[] w = glm.coefficients(0); for (int j = 0; j < cList.length; j++) { int attIndex = sd.attrs[j]; w[attIndex] *= cList[j]; } } return glms; } else { boolean refit = this.refit; this.refit = false; // Not supported in multiclass // classification int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1); List glms = new ArrayList<>(); for (int i = 0; i < numLambdas; i++) { GLM glm = new GLM(numClasses, p); glms.add(glm); } for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == k ? 
1 : 0; } List binaryClassifiers = buildBinaryClassifiers(group, indices, values, y, maxNumIters, numLambdas, minLambdaRatio); for (int l = 0; l < numLambdas; l++) { GLM binaryClassifier = binaryClassifiers.get(l); GLM glm = glms.get(l); double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } } this.refit = refit; return glms; } } else { DenseDataset dd = getDenseDataset(trainSet, true); DenseDesignMatrix dm = createDesignMatrix(dd, groups); int[] attrs = dd.attrs; int[][] group = dm.groups; double[][][] x = dm.x; double[] y = new double[dd.y.length]; double[] cList = dd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == 0 ? 1 : 0; } List glms = buildBinaryClassifiers(dm.groups, dm.x, y, maxNumIters, numLambdas, minLambdaRatio); for (GLM glm : glms) { double[] w = glm.coefficients(0); for (int j = 0; j < cList.length; j++) { int attIndex = dd.attrs[j]; w[attIndex] *= cList[j]; } } return glms; } else { boolean refit = this.refit; this.refit = false; // Not supported in multiclass // classification int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1); List glms = new ArrayList<>(); for (int i = 0; i < numLambdas; i++) { GLM glm = new GLM(numClasses, p); glms.add(glm); } for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == k ? 
1 : 0; } List binaryClassifiers = buildBinaryClassifiers(group, x, y, maxNumIters, numLambdas, minLambdaRatio); for (int l = 0; l < numLambdas; l++) { GLM binaryClassifier = binaryClassifiers.get(l); GLM glm = glms.get(l); double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } } this.refit = refit; return glms; } } } /** * Builds group-lasso penalized classifiers. * * @param trainSet the training set. * @param groups the groups of variables. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @return group-lasso penalized classifiers. */ public List buildClassifiers(Instances trainSet, List groups, int maxNumIters, int numLambdas, double minLambdaRatio) { return buildClassifiers(trainSet, isSparse(trainSet), groups, maxNumIters, numLambdas, minLambdaRatio); } /** * Builds a group-lasso penalized regressor. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param groups the groups of variables. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return a group-lasso penalized regressor. 
*/
    public GLM buildGaussianRegressor(Instances trainSet, boolean isSparse, List<int[]> groups, int maxNumIters,
            double lambda) {
        GLM glm;
        int[] attrs;
        double[] cList;
        if (isSparse) {
            SparseDataset sd = getSparseDataset(trainSet, true);
            SparseDesignMatrix sm = createDesignMatrix(sd, groups);
            attrs = sd.attrs;
            cList = sd.cList;
            glm = buildGaussianRegressor(sm.group, sm.indices, sm.values, sd.y, maxNumIters, lambda);
        } else {
            DenseDataset dd = getDenseDataset(trainSet, true);
            DenseDesignMatrix dm = createDesignMatrix(dd, groups);
            attrs = dd.attrs;
            cList = dd.cList;
            glm = buildGaussianRegressor(dm.groups, dm.x, dd.y, maxNumIters, lambda);
        }

        // Rescale each coefficient by the matching entry of cList
        double[] w = glm.coefficients(0);
        for (int j = 0; j < cList.length; j++) {
            int attIndex = attrs[j];
            w[attIndex] *= cList[j];
        }

        return glm;
    }

    /**
     * Builds a group-lasso penalized regressor. Whether the training set is treated as sparse is decided by
     * {@code isSparse(trainSet)}.
     *
     * @param trainSet the training set.
     * @param groups the groups.
     * @param maxNumIters the maximum number of iterations.
     * @param lambda the lambda.
     * @return a group-lasso penalized regressor.
     */
    public GLM buildGaussianRegressor(Instances trainSet, List<int[]> groups, int maxNumIters, double lambda) {
        return buildGaussianRegressor(trainSet, isSparse(trainSet), groups, maxNumIters, lambda);
    }

    /**
     * Builds a group-lasso penalized regressor. The input matrix is grouped by groups. This procedure does not
     * assume the data is normalized or centered.
     *
     * @param attrs the groups of variables.
     * @param x the inputs.
     * @param y the targets.
     * @param maxNumIters the maximum number of iterations.
     * @param lambda the lambda.
     * @return a group-lasso penalized regressor.
*/ public GLM buildGaussianRegressor(int[][] attrs, double[][][] x, double[] y, int maxNumIters, double lambda) { int p = 0; if (attrs.length > 0) { for (int[] attr : attrs) { p = Math.max(p, StatUtils.max(attr)); } p += 1; } // Backup targets double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } double[][] w = new double[attrs.length][]; double[] tl1 = new double[attrs.length]; int m = 0; for (int j = 0; j < attrs.length; j++) { w[j] = new double[x[j].length]; tl1[j] = lambda * Math.sqrt(w[j].length); if (w[j].length > m) { m = w[j].length; } } double[] stepSize = new double[attrs.length]; for (int j = 0; j < stepSize.length; j++) { double max = 0; double[][] block = x[j]; for (double[] t : block) { double l = StatUtils.sumSq(t); if (l > max) { max = l; } } stepSize[j] = 1.0 / max; } double[] g = new double[m]; double[] gradient = new double[m]; boolean[] activeSet = new boolean[x.length]; double intercept = 0; // Block coordinate gradient descent int iter = 0; while (iter < maxNumIters) { if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } boolean activeSetChanged = doOnePassGaussian(x, tl1, true, activeSet, w, stepSize, g, gradient, rTrain); iter++; if (!activeSetChanged || iter > maxNumIters) { break; } for (; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeGroupLassoLoss(rTrain, w, tl1); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(x, tl1, false, activeSet, w, stepSize, g, gradient, rTrain); double currLoss = GLMOptimUtils.computeGroupLassoLoss(rTrain, w, tl1); if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } } } if (refit) { boolean[] selected = new boolean[attrs.length]; for (int i = 0; i < selected.length; i++) { selected[i] = !ArrayUtils.isConstant(w[i], 0, w[i].length, 0); } return refitGaussianRegressor(p, attrs, selected, x, y, w, 
maxNumIters); } else { return getGLM(p, attrs, w, intercept, LinkFunction.IDENTITY); } } /** * Builds a group-lasso penalized regressor on sparse inputs. The input matrix is grouped by groups. This procedure does not * assume the data is normalized or centered. * * @param groups the groups of variables. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return a group-lasso penalized regressor. */ public GLM buildGaussianRegressor(int[][] groups, int[][][] indices, double[][][] values, double[] y, int maxNumIters, double lambda) { int p = 0; if (groups.length > 0) { for (int[] group : groups) { p = Math.max(p, StatUtils.max(group)); } p += 1; } // Backup targets double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } double[][] w = new double[groups.length][]; double[] tl1 = new double[groups.length]; int m = 0; for (int j = 0; j < groups.length; j++) { w[j] = new double[groups[j].length]; tl1[j] = lambda * Math.sqrt(w[j].length); if (w[j].length > m) { m = w[j].length; } } double[] stepSize = new double[groups.length]; for (int j = 0; j < stepSize.length; j++) { double max = 0; double[][] block = values[j]; for (double[] t : block) { double l = StatUtils.sumSq(t); if (l > max) { max = l; } } stepSize[j] = 1.0 / max; } double[] g = new double[m]; double[] gradient = new double[m]; boolean[] activeSet = new boolean[groups.length]; double intercept = 0; int iter = 0; while (iter < maxNumIters) { if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } boolean activeSetChanged = doOnePassGaussian(indices, values, tl1, true, activeSet, w, stepSize, g, gradient, rTrain); iter++; if (!activeSetChanged || iter > maxNumIters) { break; } for (; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeGroupLassoLoss(rTrain, w, tl1); if (fitIntercept) { intercept += 
OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(indices, values, tl1, false, activeSet, w, stepSize, g, gradient, rTrain); double currLoss = GLMOptimUtils.computeGroupLassoLoss(rTrain, w, tl1); if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } } } if (refit) { boolean[] selected = new boolean[groups.length]; for (int i = 0; i < selected.length; i++) { selected[i] = !ArrayUtils.isConstant(w[i], 0, w[i].length, 0); } return refitGaussianRegressor(p, groups, selected, indices, values, y, w, maxNumIters); } else { return getGLM(p, groups, w, intercept, LinkFunction.IDENTITY); } } /** * Builds group-lasso penalized regressors. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param groups the groups of variables. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @return group-lasso penalized regressors. 
*/
    public List<GLM> buildGaussianRegressors(Instances trainSet, boolean isSparse, List<int[]> groups,
            int maxNumIters, int numLambdas, double minLambdaRatio) {
        if (isSparse) {
            SparseDataset sd = getSparseDataset(trainSet, true);
            SparseDesignMatrix sm = createDesignMatrix(sd, groups);
            double[] cList = sd.cList;

            List<GLM> glms = buildGaussianRegressors(sm.group, sm.indices, sm.values, sd.y, maxNumIters,
                    numLambdas, minLambdaRatio);

            // Rescale each coefficient by the matching entry of cList
            for (GLM glm : glms) {
                double[] w = glm.coefficients(0);
                for (int j = 0; j < cList.length; j++) {
                    int attIndex = sd.attrs[j];
                    w[attIndex] *= cList[j];
                }
            }

            return glms;
        } else {
            // NOTE(review): this dense path calls getDenseDataset(trainSet, false) and rescales by stdList,
            // whereas the sparse path above and the single-lambda dense builder use flag true and cList.
            // Confirm the asymmetry is intended; behavior is left unchanged here.
            DenseDataset dd = getDenseDataset(trainSet, false);
            DenseDesignMatrix dm = createDesignMatrix(dd, groups);
            double[] stdList = dd.stdList;

            List<GLM> glms = buildGaussianRegressors(dm.groups, dm.x, dd.y, maxNumIters, numLambdas,
                    minLambdaRatio);

            // Rescale each coefficient by the matching entry of stdList
            for (GLM glm : glms) {
                double[] w = glm.coefficients(0);
                for (int j = 0; j < stdList.length; j++) {
                    int attIndex = dd.attrs[j];
                    w[attIndex] *= stdList[j];
                }
            }

            return glms;
        }
    }

    /**
     * Builds group-lasso penalized regressors. Whether the training set is treated as sparse is decided by
     * {@code isSparse(trainSet)}.
     *
     * @param trainSet the training set.
     * @param groups the groups of variables.
     * @param maxNumIters the maximum number of iterations.
     * @param numLambdas the number of lambdas.
     * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda.
     * @return group-lasso penalized regressors.
     */
    public List<GLM> buildGaussianRegressors(Instances trainSet, List<int[]> groups, int maxNumIters,
            int numLambdas, double minLambdaRatio) {
        return buildGaussianRegressors(trainSet, isSparse(trainSet), groups, maxNumIters, numLambdas,
                minLambdaRatio);
    }

    /**
     * Builds group-lasso penalized regressors for a sequence of regularization parameter lambdas. The input matrix
     * is grouped by groups. This procedure does not assume the data is normalized or centered.
     *
     * @param groups the groups of variables.
     * @param x the inputs.
     * @param y the targets.
     * @param maxNumIters the maximum number of iterations.
     * @param numLambdas the number of lambdas.
* @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @return group-lasso penalized regressors. */ public List buildGaussianRegressors(int[][] groups, double[][][] x, double[] y, int maxNumIters, int numLambdas, double minLambdaRatio) { int p = 0; if (groups.length > 0) { for (int[] group : groups) { p = Math.max(p, StatUtils.max(group)); } p += 1; } // Backup targets double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Allocate coefficients double[][] w = new double[x.length][]; int m = 0; for (int j = 0; j < x.length; j++) { w[j] = new double[x[j].length]; if (w[j].length > m) { m = w[j].length; } } double[] g = new double[m]; double[] gradient = new double[m]; double[] tl1 = new double[x.length]; double[] stepSize = new double[x.length]; for (int j = 0; j < stepSize.length; j++) { double max = 0; double[][] block = x[j]; for (double[] t : block) { double l = StatUtils.sumSq(t); if (l > max) { max = l; } } stepSize[j] = 1.0 / max; } boolean[] activeSet = new boolean[x.length]; // Determine max lambda double maxLambda = findMaxLambdaGaussian(x, rTrain, gradient); // Dampening factor for lambda double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas); // Compute the regularization path List glms = new ArrayList<>(numLambdas); double lambda = maxLambda; double intercept = 0; Set structures = new HashSet<>(); for (int i = 0; i < numLambdas; i++) { // Initialize regularization parameters for (int j = 0; j < tl1.length; j++) { tl1[j] = lambda * Math.sqrt(w[j].length); } // Block coordinate gradient descent int iter = 0; while (iter < maxNumIters) { if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } boolean activeSetChanged = doOnePassGaussian(x, tl1, true, activeSet, w, stepSize, g, gradient, rTrain); iter++; if (!activeSetChanged || iter > maxNumIters) { break; } for (; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeGroupLassoLoss(rTrain, w, tl1); if 
(fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(x, tl1, false, activeSet, w, stepSize, g, gradient, rTrain); double currLoss = GLMOptimUtils.computeGroupLassoLoss(rTrain, w, tl1); if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } } } lambda *= alpha; if (refit) { boolean allActivated = true; boolean[] selected = new boolean[groups.length]; for (int j = 0; j < selected.length; j++) { selected[j] = !ArrayUtils.isConstant(w[j], 0, w[j].length, 0); allActivated = allActivated & selected[j]; } ModelStructure structure = new ModelStructure(selected); if (!structures.contains(structure)) { GLM glm = refitGaussianRegressor(p, groups, selected, x, y, w, maxNumIters); glms.add(glm); structures.add(structure); } if (allActivated) { break; } } else { GLM glm = getGLM(p, groups, w, intercept, LinkFunction.IDENTITY); glms.add(glm); } } return glms; } /** * Builds group-lasso penalized regressors on sparse inputs for a sequence of regularization parameter lambdas. The input matrix is grouped by groups. This procedure does not * assume the data is normalized or centered. * * @param groups the groups of variables. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @return group-lasso penalized regressors. 
*/ public List buildGaussianRegressors(int[][] groups, int[][][] indices, double[][][] values, double[] y, int maxNumIters, int numLambdas, double minLambdaRatio) { int p = 0; if (groups.length > 0) { for (int[] group : groups) { p = Math.max(p, StatUtils.max(group)); } p += 1; } // Backup targets double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Allocate coefficients double[][] w = new double[groups.length][]; int m = 0; for (int j = 0; j < groups.length; j++) { w[j] = new double[groups[j].length]; if (w[j].length > m) { m = w[j].length; } } double[] g = new double[m]; double[] gradient = new double[m]; double[] tl1 = new double[groups.length]; double[] stepSize = new double[groups.length]; for (int j = 0; j < stepSize.length; j++) { double max = 0; double[][] block = values[j]; for (double[] t : block) { double l = StatUtils.sumSq(t); if (l > max) { max = l; } } stepSize[j] = 1.0 / max; } boolean[] activeSet = new boolean[groups.length]; // Determine max lambda double maxLambda = findMaxLambdaGaussian(indices, values, rTrain, gradient); // Dampening factor for lambda double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas); // Compute the regularization path List glms = new ArrayList<>(numLambdas); double lambda = maxLambda; double intercept = 0; Set structures = new HashSet<>(); for (int i = 0; i < numLambdas; i++) { // Initialize regularization parameters for (int j = 0; j < tl1.length; j++) { tl1[j] = lambda * Math.sqrt(w[j].length); } // Block coordinate gradient descent int iter = 0; while (iter < maxNumIters) { if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } boolean activeSetChanged = doOnePassGaussian(indices, values, tl1, true, activeSet, w, stepSize, g, gradient, rTrain); iter++; if (!activeSetChanged || iter > maxNumIters) { break; } for (; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeGroupLassoLoss(rTrain, w, tl1); if (fitIntercept) { intercept += 
OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(indices, values, tl1, false, activeSet, w, stepSize, g, gradient, rTrain); double currLoss = GLMOptimUtils.computeGroupLassoLoss(rTrain, w, tl1); if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } if (verbose) { System.out.println("Iteration " + iter + ": " + currLoss); } } } lambda *= alpha; if (refit) { boolean allActivated = true; boolean[] selected = new boolean[groups.length]; for (int j = 0; j < selected.length; j++) { selected[j] = !ArrayUtils.isConstant(w[j], 0, w[j].length, 0); allActivated = allActivated & selected[j]; } ModelStructure structure = new ModelStructure(selected); if (!structures.contains(structure)) { GLM glm = refitGaussianRegressor(p, groups, selected, indices, values, y, w, maxNumIters); glms.add(glm); structures.add(structure); } if (allActivated) { break; } } else { GLM glm = getGLM(p, groups, w, intercept, LinkFunction.IDENTITY); glms.add(glm); } } return glms; } /** * Returns the lambda. * * @return the lambda. */ public double getLambda() { return lambda; } /** * Returns the number of lambdas. * * @return the number of lambdas. */ public int getNumLambdas() { return numLambdas; } /** * Returns the task of this learner. * * @return the task of this learner. */ public Task getTask() { return task; } /** * Returns {@code true} if we refit the model. * * @return {@code true} if we refit the model. */ public boolean refit() { return refit; } /** * Sets whether we refit the model. * * @param refit {@code true} if we refit the model. */ public void refit(boolean refit) { this.refit = refit; } /** * Sets the lambda. * * @param lambda the lambda. */ public void setLambda(double lambda) { this.lambda = lambda; } /** * Sets the number of lambdas. * * @param numLambdas the number of lambdas. */ public void setNumLambdas(int numLambdas) { this.numLambdas = numLambdas; } /** * Sets the task of this learner. * * @param task the task of this learner. 
*/
	public void setTask(Task task) {
		this.task = task;
	}

	/**
	 * Returns the groups.
	 *
	 * @return the groups.
	 */
	public List<int[]> getGroups() {
		return groups;
	}

	/**
	 * Sets the groups.
	 *
	 * @param groups the groups.
	 */
	public void setGroups(List<int[]> groups) {
		this.groups = groups;
	}

	/**
	 * Computes the gradient of one dense group: for every feature row of {@code block}, its dot product with
	 * {@code rTrain} divided by {@code rTrain.length}.
	 *
	 * @param block the feature rows of the group.
	 * @param rTrain the residuals.
	 * @param gradient the output buffer; entries {@code 0 .. block.length - 1} are overwritten.
	 */
	protected void computeGradient(double[][] block, double[] rTrain, double[] gradient) {
		for (int i = 0; i < block.length; i++) {
			gradient[i] = VectorUtils.dotProduct(block[i], rTrain) / rTrain.length;
		}
	}

	/**
	 * Computes the gradient of one sparse group. Feature {@code j} stores value {@code block[j][i]} at position
	 * {@code index[j][i]} of {@code rTrain}.
	 *
	 * @param index the positions of the stored values, per feature.
	 * @param block the stored values, per feature.
	 * @param rTrain the residuals.
	 * @param gradient the output buffer; entries {@code 0 .. block.length - 1} are overwritten.
	 */
	protected void computeGradient(int[][] index, double[][] block, double[] rTrain, double[] gradient) {
		for (int j = 0; j < block.length; j++) {
			double[] t = block[j];
			int[] idx = index[j];
			gradient[j] = 0;
			for (int i = 0; i < t.length; i++) {
				gradient[j] += rTrain[idx[i]] * t[i];
			}
			gradient[j] /= rTrain.length;
		}
	}

	/**
	 * Returns {@code lambda} times the L2 norm of {@code w}.
	 *
	 * @param w the coefficients of one group.
	 * @param lambda the penalty weight of the group.
	 * @return the penalty of the group.
	 */
	protected double computePenalty(double[] w, double lambda) {
		return lambda * VectorUtils.l2norm(w);
	}

	/**
	 * Returns the sum of the per-group penalties.
	 *
	 * @param w the coefficients, one array per group.
	 * @param lambdas the penalty weights, one per group.
	 * @return the total penalty.
	 */
	protected double computePenalty(double[][] w, double[] lambdas) {
		double penalty = 0;
		for (int i = 0; i < w.length; i++) {
			penalty += computePenalty(w[i], lambdas[i]);
		}
		return penalty;
	}

	/**
	 * Builds the grouped design matrix for a dense dataset. Attribute indices in {@code groupList} that do not
	 * occur in {@code dd.attrs} are dropped, and groups left empty are omitted.
	 *
	 * @param dd the dense dataset.
	 * @param groupList the groups, each an array of attribute indices.
	 * @return the dense design matrix.
	 */
	protected DenseDesignMatrix createDesignMatrix(DenseDataset dd, List<int[]> groupList) {
		int[] attrs = dd.attrs;
		// Maps an attribute index to its position in dd.attrs (and therefore in dd.x)
		Map<Integer, Integer> attrSet = new HashMap<>();
		for (int j = 0; j < attrs.length; j++) {
			attrSet.put(attrs[j], j);
		}
		List<int[]> gList = new ArrayList<>();
		for (int[] group : groupList) {
			List<Integer> list = new ArrayList<>();
			for (int idx : group) {
				if (attrSet.containsKey(idx)) {
					list.add(attrSet.get(idx));
				}
			}
			if (list.size() > 0) {
				int[] a = new int[list.size()];
				for (int i = 0; i < a.length; i++) {
					a[i] = list.get(i);
				}
				gList.add(a);
			}
		}
		int[][] groups = new int[gList.size()][];
		double[][][] x = new double[gList.size()][][];
		for (int g = 0; g < groups.length; g++) {
			int[] group = gList.get(g);
			groups[g] = group;
			double[][] v = new double[group.length][];
			for (int j = 0; j < group.length; j++) {
				int idx = group[j];
				v[j] = dd.x[idx];
			}
			x[g] = v;
		}
		return new DenseDesignMatrix(groups, x);
	}

	protected SparseDesignMatrix
createDesignMatrix(SparseDataset sd, List<int[]> groupList) {
		// Sparse counterpart of the dense variant: attribute indices missing from sd.attrs are dropped and
		// groups left empty are omitted.
		int[] attrs = sd.attrs;
		// Maps an attribute index to its position in sd.attrs (and therefore in sd.indices / sd.values)
		Map<Integer, Integer> attrSet = new HashMap<>();
		for (int j = 0; j < attrs.length; j++) {
			attrSet.put(attrs[j], j);
		}
		List<int[]> gList = new ArrayList<>();
		for (int[] group : groupList) {
			List<Integer> list = new ArrayList<>();
			for (int idx : group) {
				if (attrSet.containsKey(idx)) {
					list.add(attrSet.get(idx));
				}
			}
			if (list.size() > 0) {
				int[] a = new int[list.size()];
				for (int i = 0; i < a.length; i++) {
					a[i] = list.get(i);
				}
				gList.add(a);
			}
		}
		int[][] groups = new int[gList.size()][];
		int[][][] indices = new int[gList.size()][][];
		double[][][] values = new double[gList.size()][][];
		for (int g = 0; g < groups.length; g++) {
			int[] group = gList.get(g);
			groups[g] = group;
			int[][] idx = new int[group.length][];
			double[][] val = new double[group.length][];
			for (int j = 0; j < group.length; j++) {
				int index = group[j];
				idx[j] = sd.indices[index];
				val[j] = sd.values[index];
			}
			indices[g] = idx;
			values[g] = val;
		}
		return new SparseDesignMatrix(groups, indices, values);
	}

	/**
	 * Performs one pass of block-wise proximal gradient updates for the Gaussian loss on dense inputs.
	 *
	 * @param x the inputs; {@code x[k][j]} is feature {@code j} of group {@code k}.
	 * @param tl1 the penalty weight of each group.
	 * @param isFullPass {@code true} to visit every group, {@code false} to visit only groups in the active set.
	 * @param activeSet the active-set flags; a full pass switches on groups that become nonzero.
	 * @param w the coefficients, one array per group; updated in place.
	 * @param stepSize the step size of each group.
	 * @param g scratch buffer for the candidate coefficients.
	 * @param gradient scratch buffer for the group gradient.
	 * @param rTrain the residuals; updated in place.
	 * @return {@code true} if a full pass activated a previously inactive group.
	 */
	protected boolean doOnePassGaussian(double[][][] x, double[] tl1, boolean isFullPass, boolean[] activeSet,
			double[][] w, double[] stepSize, double[] g, double[] gradient, double[] rTrain) {
		boolean activeSetChanged = false;

		for (int k = 0; k < x.length; k++) {
			if (!isFullPass && !activeSet[k]) {
				continue;
			}

			double[][] block = x[k];
			double[] beta = w[k];
			double tk = stepSize[k];

			// Proximal gradient method
			computeGradient(block, rTrain, gradient);
			for (int j = 0; j < beta.length; j++) {
				g[j] = beta[j] + tk * gradient[j];
			}
			double norm = Math.sqrt(StatUtils.sumSq(g, 0, beta.length));
			double lambda = tl1[k] * tk;
			// Group shrinkage: scale by (1 - lambda / norm), or zero the group when norm <= lambda
			if (norm > lambda) {
				VectorUtils.multiply(g, (1 - lambda / norm));
			} else {
				Arrays.fill(g, 0);
			}

			// Update residuals
			for (int j = 0; j < beta.length; j++) {
				double[] t = block[j];
				double delta = beta[j] - g[j];
				for (int i = 0; i < rTrain.length; i++) {
					rTrain[i] += delta * t[i];
				}
			}

			// Update weights
			for (int j = 0; j < beta.length; j++) {
				beta[j] = g[j];
} if (isFullPass && !activeSet[k] && !ArrayUtils.isConstant(beta, 0, beta.length, 0)) { activeSetChanged = true; activeSet[k] = true; } } return activeSetChanged; } protected boolean doOnePassGaussian(int[][][] indices, double[][][] values, double[] tl1, boolean isFullPass, boolean[] activeSet, double[][] w, double[] stepSize, double[] g, double[] gradient, double[] rTrain) { boolean activeSetChanged = false; for (int k = 0; k < values.length; k++) { if (!isFullPass && !activeSet[k]) { continue; } double[][] block = values[k]; int[][] index = indices[k]; double[] beta = w[k]; double tk = tl1[k]; // Proximal gradient method computeGradient(index, block, rTrain, gradient); for (int j = 0; j < beta.length; j++) { g[j] = beta[j] + tk * gradient[j]; } double norm = Math.sqrt(StatUtils.sumSq(g, 0, beta.length)); double lambda = tl1[k] * tk; if (norm > lambda) { VectorUtils.multiply(g, (1 - lambda / norm)); } else { VectorUtils.multiply(g, 0); } // Update predictions for (int j = 0; j < beta.length; j++) { int[] idx = index[j]; double[] t = block[j]; double delta = beta[j] - g[j]; for (int i = 0; i < t.length; i++) { rTrain[idx[i]] += delta * t[i]; } } // Update weights for (int j = 0; j < beta.length; j++) { beta[j] = g[j]; } if (isFullPass && !activeSet[k] && !ArrayUtils.isConstant(beta, 0, beta.length, 0)) { activeSetChanged = true; activeSet[k] = true; } } return activeSetChanged; } protected boolean doOnePassBinomial(double[][][] x, double[] y, double[] tl1, boolean isFullPass, boolean[] activeSet, double[][] w, double[] stepSize, double[] g, double[] gradient, double[] pTrain, double[] rTrain) { boolean activeSetChanged = false; for (int k = 0; k < x.length; k++) { if (!isFullPass && !activeSet[k]) { continue; } double[][] block = x[k]; double[] beta = w[k]; double tk = stepSize[k]; // Proximal gradient method computeGradient(block, rTrain, gradient); for (int j = 0; j < beta.length; j++) { g[j] = beta[j] + tk * gradient[j]; } double norm = 
Math.sqrt(StatUtils.sumSq(g, 0, beta.length));
			double lambda = tl1[k] * tk;
			// Group shrinkage: scale by (1 - lambda / norm), or zero the group when norm <= lambda
			if (norm > lambda) {
				VectorUtils.multiply(g, (1 - lambda / norm));
			} else {
				Arrays.fill(g, 0);
			}

			// Update predictions
			for (int j = 0; j < beta.length; j++) {
				double[] t = block[j];
				double delta = g[j] - beta[j];
				for (int i = 0; i < y.length; i++) {
					pTrain[i] += delta * t[i];
				}
			}
			// Recompute all residuals from the updated predictions
			OptimUtils.computePseudoResidual(pTrain, y, rTrain);

			// Update weights
			for (int j = 0; j < beta.length; j++) {
				beta[j] = g[j];
			}

			// Only a full pass may grow the active set
			if (isFullPass && !activeSet[k] && !ArrayUtils.isConstant(beta, 0, beta.length, 0)) {
				activeSetChanged = true;
				activeSet[k] = true;
			}
		}

		return activeSetChanged;
	}

	/**
	 * Performs one pass of block-wise proximal gradient updates for the binomial loss on sparse inputs.
	 *
	 * @param indices the positions of the stored values; {@code indices[k][j]} belongs to feature {@code j} of
	 *            group {@code k}.
	 * @param indexUnion for each group, the positions whose residuals are recomputed after the group is updated.
	 * @param values the stored values, laid out like {@code indices}.
	 * @param y the targets.
	 * @param tl1 the penalty weight of each group.
	 * @param isFullPass {@code true} to visit every group, {@code false} to visit only groups in the active set.
	 * @param activeSet the active-set flags.
	 * @param w the coefficients, one array per group; updated in place.
	 * @param stepSize the step size of each group.
	 * @param g scratch buffer for the candidate coefficients.
	 * @param gradient scratch buffer for the group gradient.
	 * @param pTrain the predictions; updated in place.
	 * @param rTrain the residuals; updated in place at the positions listed in {@code indexUnion}.
	 * @return {@code true} if the active set changed.
	 */
	protected boolean doOnePassBinomial(int[][][] indices, int[][] indexUnion, double[][][] values, double[] y,
			double[] tl1, boolean isFullPass, boolean[] activeSet, double[][] w, double[] stepSize, double[] g,
			double[] gradient, double[] pTrain, double[] rTrain) {
		boolean activeSetChanged = false;

		for (int k = 0; k < values.length; k++) {
			if (!isFullPass && !activeSet[k]) {
				continue;
			}

			int[][] index = indices[k];
			double[][] block = values[k];
			double[] beta = w[k];
			double tk = stepSize[k];

			// Proximal gradient method
			computeGradient(index, block, rTrain, gradient);
			for (int j = 0; j < beta.length; j++) {
				g[j] = beta[j] + tk * gradient[j];
			}
			double norm = Math.sqrt(StatUtils.sumSq(g, 0, beta.length));
			double lambda = tl1[k] * tk;
			// Group shrinkage: scale by (1 - lambda / norm), or zero the group when norm <= lambda
			if (norm > lambda) {
				VectorUtils.multiply(g, (1 - lambda / norm));
			} else {
				VectorUtils.multiply(g, 0);
			}

			// Update predictions
			for (int j = 0; j < beta.length; j++) {
				int[] idx = index[j];
				double[] value = block[j];
				double delta = g[j] - beta[j];
				for (int i = 0; i < value.length; i++) {
					pTrain[idx[i]] += delta * value[i];
				}
			}
			// Only the positions listed for this group get their residuals refreshed
			int[] idxUnion = indexUnion[k];
			for (int idx : idxUnion) {
				rTrain[idx] = OptimUtils.getPseudoResidual(pTrain[idx], y[idx]);
			}

			// Update weights
			for (int j = 0; j < beta.length; j++) {
				beta[j] = g[j];
			}

			// Only a full pass may grow the active set
			if (isFullPass && !activeSet[k] && !ArrayUtils.isConstant(beta, 0, beta.length, 0))
{ activeSetChanged = true; activeSet[k] = true; } } return activeSetChanged; } protected double findMaxLambdaGaussian(double[][][] x, double[] rTrain, double[] gradient) { double mean = 0; if (fitIntercept) { mean = OptimUtils.fitIntercept(rTrain); } double maxLambda = 0; for (double[][] block : x) { computeGradient(block, rTrain, gradient); double t = Math.sqrt(StatUtils.sumSq(gradient, 0, block.length)) / Math.sqrt(block.length); if (t > maxLambda) { maxLambda = t; } } if (fitIntercept) { VectorUtils.add(rTrain, mean); } return maxLambda; } protected double findMaxLambdaGaussian(int[][][] indices, double[][][] values, double[] rTrain, double[] gradient) { double mean = 0; if (fitIntercept) { mean = OptimUtils.fitIntercept(rTrain); } double maxLambda = 0; for (int g = 0; g < values.length; g++) { int[][] index = indices[g]; double[][] block = values[g]; computeGradient(index, block, rTrain, gradient); double t = Math.sqrt(StatUtils.sumSq(gradient, 0, block.length)) / Math.sqrt(block.length); if (t > maxLambda) { maxLambda = t; } } if (fitIntercept) { VectorUtils.add(rTrain, mean); } return maxLambda; } protected double findMaxLambdaBinomial(double[][][] x, double[] y, double[] pTrain, double[] rTrain, double[] gradient) { if (fitIntercept) { OptimUtils.fitIntercept(pTrain, rTrain, y); } double maxLambda = 0; for (double[][] block : x) { computeGradient(block, rTrain, gradient); double t = Math.sqrt(StatUtils.sumSq(gradient, 0, block.length)) / Math.sqrt(block.length); if (t > maxLambda) { maxLambda = t; } } if (fitIntercept) { Arrays.fill(pTrain, 0); OptimUtils.computePseudoResidual(pTrain, y, rTrain); } return maxLambda; } protected double findMaxLambdaBinomial(int[][][] indices, double[][][] values, double[] y, double[] pTrain, double[] rTrain, double[] gradient) { if (fitIntercept) { OptimUtils.fitIntercept(pTrain, rTrain, y); } double maxLambda = 0; for (int j = 0; j < values.length; j++) { int[][] index = indices[j]; double[][] block = values[j]; 
computeGradient(index, block, rTrain, gradient);
			double t = Math.sqrt(StatUtils.sumSq(gradient, 0, block.length)) / Math.sqrt(block.length);
			if (t > maxLambda) {
				maxLambda = t;
			}
		}
		if (fitIntercept) {
			Arrays.fill(pTrain, 0);
			OptimUtils.computePseudoResidual(pTrain, y, rTrain);
		}
		return maxLambda;
	}

	/**
	 * Builds a GLM from a flat coefficient array that covers only the selected groups. Coefficients are consumed
	 * in order: group by group, and attribute by attribute within each selected group.
	 *
	 * @param p the dimension passed to the {@code GLM} constructor.
	 * @param attrs the attribute indices of each group.
	 * @param selected flags marking the groups that {@code coef} covers.
	 * @param coef the coefficients of the selected groups, concatenated.
	 * @param intercept the intercept.
	 * @param link the link function.
	 * @return the GLM.
	 */
	protected GLM getGLM(int p, int[][] attrs, boolean[] selected, double[] coef, double intercept,
			LinkFunction link) {
		GLM glm = new GLM(p);
		int k = 0;
		double[] w = glm.coefficients(0);
		for (int g = 0; g < attrs.length; g++) {
			if (selected[g]) {
				int[] attr = attrs[g];
				for (int attIndex : attr) {
					w[attIndex] = coef[k++];
				}
			}
		}
		glm.intercept[0] = intercept;
		glm.link = link;
		return glm;
	}

	/**
	 * Builds a GLM from per-group coefficient arrays.
	 *
	 * @param p the dimension passed to the {@code GLM} constructor.
	 * @param attrs the attribute indices of each group.
	 * @param coef the coefficients of each group, aligned with {@code attrs}.
	 * @param intercept the intercept.
	 * @param link the link function.
	 * @return the GLM.
	 */
	protected GLM getGLM(int p, int[][] attrs, double[][] coef, double intercept, LinkFunction link) {
		GLM glm = new GLM(p);
		double[] w = glm.coefficients(0);
		for (int g = 0; g < attrs.length; g++) {
			int[] attr = attrs[g];
			double[] beta = coef[g];
			for (int j = 0; j < attr.length; j++) {
				w[attr[j]] = beta[j];
			}
		}
		glm.intercept[0] = intercept;
		glm.link = link;
		return glm;
	}

	/**
	 * Refits a binary classifier on the features of the selected groups only (dense inputs), using a ridge learner
	 * with a very small regularization parameter.
	 *
	 * @param p the dimension of the returned GLM.
	 * @param groups the attribute indices of each group.
	 * @param selected flags marking the selected groups.
	 * @param x the inputs; {@code x[g]} holds the feature rows of group {@code g}.
	 * @param y the targets.
	 * @param w the current coefficients (not read by this method).
	 * @param maxNumIters the maximum number of iterations.
	 * @return the refitted classifier.
	 */
	protected GLM refitClassifier(int p, int[][] groups, boolean[] selected, double[][][] x, double[] y,
			double[][] w, int maxNumIters) {
		// Flatten the feature rows of the selected groups, preserving group order
		List<double[]> xList = new ArrayList<>();
		for (int g = 0; g < selected.length; g++) {
			if (selected[g]) {
				double[][] t = x[g];
				for (int j = 0; j < t.length; j++) {
					xList.add(t[j]);
				}
			}
		}
		double[][] xNew = new double[xList.size()][];
		for (int i = 0; i < xNew.length; i++) {
			xNew[i] = xList.get(i);
		}
		// The refit works on consecutive attribute indices 0 .. n-1; getGLM maps them back
		int[] attrs = new int[xNew.length];
		for (int i = 0; i < attrs.length; i++) {
			attrs[i] = i;
		}
		RidgeLearner ridgeLearner = new RidgeLearner();
		ridgeLearner.setVerbose(verbose);
		ridgeLearner.setEpsilon(epsilon);
		ridgeLearner.fitIntercept(fitIntercept);
		// A ridge regression with very small regularization parameter
		// This often improves stability a lot
		GLM glm = ridgeLearner.buildBinaryClassifier(attrs, xNew, y, maxNumIters, 1e-8);
		return getGLM(p, groups, selected, glm.coefficients(0), glm.intercept(0),
LinkFunction.LOGIT);
	}

	/**
	 * Refits a binary classifier on the features of the selected groups only (sparse inputs), using a ridge
	 * learner with a very small regularization parameter.
	 *
	 * @param p the dimension of the returned GLM.
	 * @param groups the attribute indices of each group.
	 * @param selected flags marking the selected groups.
	 * @param indices the positions of the stored values, per group and feature.
	 * @param values the stored values, per group and feature.
	 * @param y the targets.
	 * @param w the current coefficients (not read by this method).
	 * @param maxNumIters the maximum number of iterations.
	 * @return the refitted classifier.
	 */
	protected GLM refitClassifier(int p, int[][] groups, boolean[] selected, int[][][] indices,
			double[][][] values, double[] y, double[][] w, int maxNumIters) {
		// Flatten the features of the selected groups, preserving group order
		List<int[]> iList = new ArrayList<>();
		List<double[]> vList = new ArrayList<>();
		for (int g = 0; g < selected.length; g++) {
			if (selected[g]) {
				int[][] iBlock = indices[g];
				double[][] vBlock = values[g];
				for (int j = 0; j < vBlock.length; j++) {
					iList.add(iBlock[j]);
					vList.add(vBlock[j]);
				}
			}
		}
		int[][] idxNew = new int[iList.size()][];
		for (int i = 0; i < idxNew.length; i++) {
			idxNew[i] = iList.get(i);
		}
		double[][] valNew = new double[vList.size()][];
		for (int i = 0; i < valNew.length; i++) {
			valNew[i] = vList.get(i);
		}
		// The refit works on consecutive attribute indices 0 .. n-1; getGLM maps them back
		int[] attrs = new int[valNew.length];
		for (int i = 0; i < attrs.length; i++) {
			attrs[i] = i;
		}
		RidgeLearner ridgeLearner = new RidgeLearner();
		ridgeLearner.setVerbose(verbose);
		ridgeLearner.setEpsilon(epsilon);
		ridgeLearner.fitIntercept(fitIntercept);
		// A ridge regression with very small regularization parameter
		// This often improves stability a lot
		GLM glm = ridgeLearner.buildBinaryClassifier(attrs, idxNew, valNew, y, maxNumIters, 1e-8);
		return getGLM(p, groups, selected, glm.coefficients(0), glm.intercept(0), LinkFunction.LOGIT);
	}

	/**
	 * Refits a Gaussian regressor on the features of the selected groups only (dense inputs), using a ridge
	 * learner with a very small regularization parameter. When no group is selected, the result is a
	 * default-constructed GLM whose intercept is the mean of {@code y} if an intercept is fitted.
	 *
	 * @param p the dimension of the returned GLM.
	 * @param attrs the attribute indices of each group.
	 * @param selected flags marking the selected groups.
	 * @param x the inputs; {@code x[g]} holds the feature rows of group {@code g}.
	 * @param y the targets.
	 * @param w the current coefficients.
	 * @param maxNumIters the maximum number of iterations.
	 * @return the refitted regressor.
	 */
	protected GLM refitGaussianRegressor(int p, int[][] attrs, boolean[] selected, double[][][] x, double[] y,
			double[][] w, int maxNumIters) {
		// Flatten the feature rows of the selected groups, preserving group order
		List<double[]> xList = new ArrayList<>();
		for (int g = 0; g < selected.length; g++) {
			if (selected[g]) {
				double[][] t = x[g];
				for (int j = 0; j < t.length; j++) {
					xList.add(t[j]);
				}
			}
		}
		if (xList.size() == 0) {
			// NOTE(review): the empty model is built with new GLM(), not new GLM(p) - confirm that the
			// default constructor yields a model of compatible dimension.
			if (fitIntercept) {
				double intercept = StatUtils.mean(y);
				GLM glm = new GLM();
				glm.intercept[0] = intercept;
				return glm;
			} else {
				return new GLM();
			}
		}
		double[][] xNew = new double[xList.size()][];
		for (int i = 0; i < xNew.length; i++) {
			xNew[i] = xList.get(i);
		}
		int[] attrsNew = new int[xNew.length];
		for (int j = 0; j < attrsNew.length; j++) {
			attrsNew[j] = j;
		}
		RidgeLearner ridgeLearner = new RidgeLearner();
ridgeLearner.setVerbose(verbose); ridgeLearner.setEpsilon(epsilon); ridgeLearner.fitIntercept(fitIntercept); // A ridge regression with very small regularization parameter // This often improves stability a lot GLM glm = ridgeLearner.buildGaussianRegressor(attrsNew, xNew, y, maxNumIters, 1e-8); return getGLM(p, attrs, selected, glm.coefficients(0), glm.intercept(0), LinkFunction.IDENTITY); } protected GLM refitGaussianRegressor(int p, int[][] groups, boolean[] selected, int[][][] indices, double[][][] values, double[] y, double[][] w, int maxNumIters) { List iList = new ArrayList<>(); List vList = new ArrayList<>(); for (int g = 0; g < selected.length; g++) { if (selected[g]) { int[][] iBlock = indices[g]; double[][] vBlock = values[g]; for (int j = 0; j < vBlock.length; j++) { iList.add(iBlock[j]); vList.add(vBlock[j]); } } } if (vList.size() == 0) { if (fitIntercept) { double intercept = StatUtils.mean(y); GLM glm = new GLM(); glm.intercept[0] = intercept; return glm; } else { return new GLM(); } } int[][] idxNew = new int[iList.size()][]; for (int i = 0; i < idxNew.length; i++) { idxNew[i] = iList.get(i); } double[][] valNew = new double[vList.size()][]; for (int i = 0; i < valNew.length; i++) { valNew[i] = vList.get(i); } int[] attrs = new int[valNew.length]; for (int i = 0; i < attrs.length; i++) { attrs[i] = i; } RidgeLearner ridgeLearner = new RidgeLearner(); ridgeLearner.setVerbose(verbose); ridgeLearner.setEpsilon(epsilon); ridgeLearner.fitIntercept(fitIntercept); // A ridge regression with very small regularization parameter // This often improves stability a lot GLM glm = ridgeLearner.buildGaussianRegressor(attrs, idxNew, valNew, y, maxNumIters, 1e-8); return getGLM(p, groups, selected, glm.coefficients(0), glm.intercept(0), LinkFunction.IDENTITY); } } ================================================ FILE: src/main/java/mltk/predictor/glm/LassoLearner.java ================================================ package mltk.predictor.glm; import 
java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Set; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.cmdline.options.LearnerWithTaskOptions; import mltk.core.Attribute; import mltk.core.DenseVector; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.core.SparseVector; import mltk.core.io.InstancesReader; import mltk.predictor.Family; import mltk.predictor.LinkFunction; import mltk.predictor.io.PredictorWriter; import mltk.util.MathUtils; import mltk.util.OptimUtils; import mltk.util.StatUtils; import mltk.util.VectorUtils; /** * Class for learning L1-regularized linear model via coordinate descent. * * @author Yin Lou * */ public class LassoLearner extends GLMLearner { static class Options extends LearnerWithTaskOptions { @Argument(name = "-m", description = "maximum num of iterations (default: 0)") int maxIter = 0; @Argument(name = "-l", description = "lambda (default: 0)") double lambda = 0; } /** * Trains L1-regularized GLMs. * *
	 * Usage: mltk.predictor.glm.LassoLearner
	 * -t	train set path
	 * [-g]	task between classification (c) and regression (r) (default: r)
	 * [-r]	attribute file path
	 * [-o]	output model path
	 * [-V]	verbose (default: true)
	 * [-m]	maximum num of iterations (default: 0)
	 * [-l]	lambda (default: 0)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(LassoLearner.class, opts); Task task = null; try { parser.parse(args); task = Task.get(opts.task); } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath); LassoLearner learner = new LassoLearner(); learner.setVerbose(opts.verbose); learner.setTask(task); learner.setLambda(opts.lambda); learner.setMaxNumIters(opts.maxIter); long start = System.currentTimeMillis(); GLM glm = learner.build(trainSet); long end = System.currentTimeMillis(); System.out.println("Time: " + (end - start) / 1000.0); if (opts.outputModelPath != null) { PredictorWriter.write(glm, opts.outputModelPath); } } static class ModelStructure { boolean[] structure; ModelStructure(boolean[] structure) { this.structure = structure; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; ModelStructure other = (ModelStructure) obj; if (!Arrays.equals(structure, other.structure)) return false; return true; } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + Arrays.hashCode(structure); return result; } } protected boolean refit; protected int numLambdas; protected double lambda; protected Task task; /** * Constructor. 
*/ public LassoLearner() { refit = false; lambda = 0; // no regularization numLambdas = -1; // no regularization path task = Task.REGRESSION; } @Override public GLM build(Instances instances) { GLM glm = null; if (maxNumIters < 0) { maxNumIters = 20; } switch (task) { case REGRESSION: glm = buildGaussianRegressor(instances, maxNumIters, lambda); break; case CLASSIFICATION: glm = buildClassifier(instances, maxNumIters, lambda); break; default: break; } return glm; } @Override public GLM build(Instances trainSet, Family family) { GLM glm = null; if (maxNumIters < 0) { maxNumIters = 20; } switch (family) { case GAUSSIAN: glm = buildGaussianRegressor(trainSet, maxNumIters, lambda); break; case BINOMIAL: glm = buildClassifier(trainSet, maxNumIters, lambda); break; default: throw new IllegalArgumentException("Unsupported family: " + family); } return glm; } /** * Builds an L1-regularized binary classifier. Each row in the input matrix x represents a feature (instead of a * data point). Thus the input matrix is the transpose of the row-oriented data matrix. This procedure does not * assume the data is normalized or centered. * * @param attrs the attribute list. * @param x the inputs. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return an L1-regularized classifier. 
*/ public GLM buildBinaryClassifier(int[] attrs, double[][] x, double[] y, int maxNumIters, double lambda) { double[] w = new double[attrs.length]; double intercept = 0; double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); // Calculate theta's double[] theta = new double[x.length]; for (int i = 0; i < x.length; i++) { theta[i] = StatUtils.sumSq(x[i]) / 4; } // Coordinate descent final double tl1 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(x, theta, y, tl1, w, pTrain, rTrain); double currLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } if (refit) { boolean[] selected = new boolean[attrs.length]; for (int i = 0; i < selected.length; i++) { selected[i] = w[i] != 0; } return refitClassifier(attrs, selected, x, y, maxNumIters); } else { return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.LOGIT); } } /** * Builds an L1-regularized binary classifier on sparse inputs. Each row of the input represents a feature (instead * of a data point), i.e., in column-oriented format. This procedure does not assume the data is normalized or * centered. * * @param attrs the attribute list. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return an L1-regularized classifier. 
*/ public GLM buildBinaryClassifier(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters, double lambda) { double[] w = new double[attrs.length]; double intercept = 0; double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); // Calculate theta's double[] theta = new double[values.length]; for (int i = 0; i < values.length; i++) { theta[i] = StatUtils.sumSq(values[i]) / 4; } // Coordinate descent final double tl1 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(indices, values, theta, y, tl1, w, pTrain, rTrain); double currLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } if (refit) { boolean[] selected = new boolean[attrs.length]; for (int i = 0; i < selected.length; i++) { selected[i] = w[i] != 0; } return refitClassifier(attrs, selected, indices, values, y, maxNumIters); } else { return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.LOGIT); } } /** * Builds L1-regularized binary classifiers for a sequence of regularization parameter lambdas. Each row in the * input matrix x represents a feature (instead of a data point). Thus the input matrix is the transpose of the * row-oriented data matrix. This procedure does not assume the data is normalized or centered. * * @param attrs the attribute list. * @param x the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @return L1-regularized classifiers. 
*/ public List buildBinaryClassifiers(int[] attrs, double[][] x, double[] y, int maxNumIters, int numLambdas, double minLambdaRatio) { double[] w = new double[attrs.length]; double intercept = 0; double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); // Calculate theta's double[] theta = new double[x.length]; for (int i = 0; i < x.length; i++) { theta[i] = StatUtils.sumSq(x[i]) / 4; } double maxLambda = findMaxLambdaBinomial(x, y, pTrain, rTrain); // Dampening factor for lambda double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas); // Compute the regularization path List glms = new ArrayList<>(numLambdas); Set structures = new HashSet<>(); double lambda = maxLambda; for (int g = 0; g < numLambdas; g++) { // Coordinate descent final double tl1 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(x, theta, y, tl1, w, pTrain, rTrain); double currLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } lambda *= alpha; if (refit) { boolean[] selected = new boolean[attrs.length]; for (int i = 0; i < selected.length; i++) { selected[i] = w[i] != 0; } ModelStructure structure = new ModelStructure(selected); if (!structures.contains(structure)) { GLM glm = refitClassifier(attrs, selected, x, y, maxNumIters); glms.add(glm); structures.add(structure); } } else { GLM glm = GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.LOGIT); glms.add(glm); } } return glms; } /** * Builds L1-regularized binary classifiers for a sequence of regularization parameter lambdas on sparse format. 
* Each row of the input represents a feature (instead of a data point), i.e., in column-oriented format. This * procedure does not assume the data is normalized or centered. * * @param attrs the attribute list. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @return L1-regularized classifiers. */ public List buildBinaryClassifiers(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters, int numLambdas, double minLambdaRatio) { double[] w = new double[attrs.length]; double intercept = 0; double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); // Calculate theta's double[] theta = new double[values.length]; for (int i = 0; i < values.length; i++) { theta[i] = StatUtils.sumSq(values[i]) / 4; } double maxLambda = findMaxLambdaBinomial(indices, values, y, pTrain, rTrain); // Dampening factor for lambda double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas); // Compute the regularization path List glms = new ArrayList<>(numLambdas); Set structures = new HashSet<>(); double lambda = maxLambda; for (int g = 0; g < numLambdas; g++) { // Coordinate descent final double tl1 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(indices, values, theta, y, tl1, w, pTrain, rTrain); double currLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } lambda *= alpha; if (refit) { boolean[] selected = new boolean[attrs.length]; for (int i = 0; 
i < selected.length; i++) {
					selected[i] = w[i] != 0;
				}
				// Refit once per distinct set of selected features
				ModelStructure structure = new ModelStructure(selected);
				if (!structures.contains(structure)) {
					GLM glm = refitClassifier(attrs, selected, indices, values, y, maxNumIters);
					glms.add(glm);
					structures.add(structure);
				}
			} else {
				GLM glm = GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.LOGIT);
				glms.add(glm);
			}
		}

		return glms;
	}

	/**
	 * Builds an L1-regularized classifier. With two classes a single binary model is trained in which class 0 is
	 * encoded as target 1; with more classes one binary model is trained per class (one-vs-the-rest).
	 *
	 * @param trainSet the training set.
	 * @param isSparse {@code true} if the training set is treated as sparse.
	 * @param maxNumIters the maximum number of iterations.
	 * @param lambda the lambda.
	 * @return an L1-regularized classifier.
	 * @throws IllegalArgumentException if the target attribute is not nominal.
	 */
	public GLM buildClassifier(Instances trainSet, boolean isSparse, int maxNumIters, double lambda) {
		Attribute classAttribute = trainSet.getTargetAttribute();
		if (classAttribute.getType() != Attribute.Type.NOMINAL) {
			throw new IllegalArgumentException("Class attribute must be nominal.");
		}
		NominalAttribute clazz = (NominalAttribute) classAttribute;
		int numClasses = clazz.getCardinality();

		if (isSparse) {
			SparseDataset sd = getSparseDataset(trainSet, true);
			int[] attrs = sd.attrs;
			int[][] indices = sd.indices;
			double[][] values = sd.values;
			double[] y = new double[sd.y.length];
			double[] cList = sd.cList;

			if (numClasses == 2) {
				// Class 0 is the positive class
				for (int i = 0; i < y.length; i++) {
					int label = (int) sd.y[i];
					y[i] = label == 0 ? 1 : 0;
				}

				GLM glm = buildBinaryClassifier(attrs, indices, values, y, maxNumIters, lambda);

				// Multiply each coefficient by its cList factor
				// NOTE(review): cList presumably holds the per-attribute scaling applied by
				// getSparseDataset - confirm.
				double[] w = glm.w[0];
				for (int j = 0; j < cList.length; j++) {
					int attIndex = sd.attrs[j];
					w[attIndex] *= cList[j];
				}

				return glm;
			} else {
				// Model dimension: 1 + largest attribute index (0 when there are no attributes)
				int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1);
				GLM glm = new GLM(numClasses, p);

				for (int k = 0; k < numClasses; k++) {
					// One-vs-the-rest
					for (int i = 0; i < y.length; i++) {
						int label = (int) sd.y[i];
						y[i] = label == k ?
1 : 0;
					}

					GLM binaryClassifier = buildBinaryClassifier(attrs, indices, values, y, maxNumIters, lambda);

					// Copy the binary model into row k, multiplying each coefficient by its cList factor
					double[] w = binaryClassifier.w[0];
					for (int j = 0; j < cList.length; j++) {
						int attIndex = attrs[j];
						glm.w[k][attIndex] = w[attIndex] * cList[j];
					}
					glm.intercept[k] = binaryClassifier.intercept[0];
				}

				return glm;
			}
		} else {
			DenseDataset dd = getDenseDataset(trainSet, true);
			int[] attrs = dd.attrs;
			double[][] x = dd.x;
			double[] y = new double[dd.y.length];
			double[] cList = dd.cList;

			if (numClasses == 2) {
				// Class 0 is the positive class
				for (int i = 0; i < y.length; i++) {
					int label = (int) dd.y[i];
					y[i] = label == 0 ? 1 : 0;
				}

				GLM glm = buildBinaryClassifier(attrs, x, y, maxNumIters, lambda);

				// Multiply each coefficient by its cList factor
				double[] w = glm.w[0];
				for (int j = 0; j < cList.length; j++) {
					int attIndex = attrs[j];
					w[attIndex] *= cList[j];
				}

				return glm;
			} else {
				// Model dimension: 1 + largest attribute index (0 when there are no attributes)
				int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1);
				GLM glm = new GLM(numClasses, p);

				for (int k = 0; k < numClasses; k++) {
					// One-vs-the-rest
					for (int i = 0; i < y.length; i++) {
						int label = (int) dd.y[i];
						y[i] = label == k ? 1 : 0;
					}

					GLM binaryClassifier = buildBinaryClassifier(attrs, x, y, maxNumIters, lambda);

					// Copy the binary model into row k, multiplying each coefficient by its cList factor
					double[] w = binaryClassifier.w[0];
					for (int j = 0; j < cList.length; j++) {
						int attIndex = attrs[j];
						glm.w[k][attIndex] = w[attIndex] * cList[j];
					}
					glm.intercept[k] = binaryClassifier.intercept[0];
				}

				return glm;
			}
		}
	}

	/**
	 * Builds an L1-regularized classifier. The sparse or dense code path is chosen by {@code isSparse(trainSet)}.
	 *
	 * @param trainSet the training set.
	 * @param maxNumIters the maximum number of iterations.
	 * @param lambda the lambda.
	 * @return an L1-regularized classifier.
	 */
	public GLM buildClassifier(Instances trainSet, int maxNumIters, double lambda) {
		return buildClassifier(trainSet, isSparse(trainSet), maxNumIters, lambda);
	}

	/**
	 * Builds L1-regularized classifiers for a sequence of regularization parameter lambdas.
	 *
	 * @param trainSet the training set.
	 * @param isSparse {@code true} if the training set is treated as sparse.
	 * @param maxNumIters the maximum number of iterations.
* @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @return L1-regularized classifiers. */ public List buildClassifiers(Instances trainSet, boolean isSparse, int maxNumIters, int numLambdas, double minLambdaRatio) { Attribute classAttribute = trainSet.getTargetAttribute(); if (classAttribute.getType() != Attribute.Type.NOMINAL) { throw new IllegalArgumentException("Class attribute must be nominal."); } NominalAttribute clazz = (NominalAttribute) classAttribute; int numClasses = clazz.getCardinality(); if (isSparse) { SparseDataset sd = getSparseDataset(trainSet, true); int[] attrs = sd.attrs; int[][] indices = sd.indices; double[][] values = sd.values; double[] y = new double[sd.y.length]; double[] cList = sd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == 0 ? 1 : 0; } List glms = buildBinaryClassifiers(attrs, indices, values, y, maxNumIters, numLambdas, minLambdaRatio); for (GLM glm : glms) { double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = sd.attrs[j]; w[attIndex] *= cList[j]; } } return glms; } else { boolean refit = this.refit; this.refit = false; // Not supported in multiclass // classification int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1); List glms = new ArrayList<>(); for (int i = 0; i < numLambdas; i++) { GLM glm = new GLM(numClasses, p); glms.add(glm); } for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == k ? 
1 : 0; } List binaryClassifiers = buildBinaryClassifiers(attrs, indices, values, y, maxNumIters, numLambdas, minLambdaRatio); for (int l = 0; l < numLambdas; l++) { GLM binaryClassifier = binaryClassifiers.get(l); GLM glm = glms.get(l); double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } } this.refit = refit; return glms; } } else { DenseDataset dd = getDenseDataset(trainSet, true); int[] attrs = dd.attrs; double[][] x = dd.x; double[] y = new double[dd.y.length]; double[] cList = dd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == 0 ? 1 : 0; } List glms = buildBinaryClassifiers(attrs, x, y, maxNumIters, numLambdas, minLambdaRatio); for (GLM glm : glms) { double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = dd.attrs[j]; w[attIndex] *= cList[j]; } } return glms; } else { boolean refit = this.refit; this.refit = false; // Not supported in multiclass // classification int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1); List glms = new ArrayList<>(); for (int i = 0; i < numLambdas; i++) { GLM glm = new GLM(numClasses, p); glms.add(glm); } for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == k ? 1 : 0; } List binaryClassifiers = buildBinaryClassifiers(attrs, x, y, maxNumIters, numLambdas, minLambdaRatio); for (int l = 0; l < numLambdas; l++) { GLM binaryClassifier = binaryClassifiers.get(l); GLM glm = glms.get(l); double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } } this.refit = refit; return glms; } } } /** * Builds L1-regularized classifiers for a sequence of regularization parameter lambdas. 
* * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @return L1-regularized classifiers. */ public List buildClassifiers(Instances trainSet, int maxNumIters, int numLambdas, double minLambdaRatio) { return buildClassifiers(trainSet, isSparse(trainSet), maxNumIters, numLambdas, minLambdaRatio); } /** * Builds an L1-regularized regressor. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return an L1-regularized regressor. */ public GLM buildGaussianRegressor(Instances trainSet, boolean isSparse, int maxNumIters, double lambda) { if (isSparse) { SparseDataset sd = getSparseDataset(trainSet, true); double[] cList = sd.cList; GLM glm = buildGaussianRegressor(sd.attrs, sd.indices, sd.values, sd.y, maxNumIters, lambda); double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = sd.attrs[j]; w[attIndex] *= cList[j]; } return glm; } else { DenseDataset dd = getDenseDataset(trainSet, true); double[] cList = dd.cList; GLM glm = buildGaussianRegressor(dd.attrs, dd.x, dd.y, maxNumIters, lambda); double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = dd.attrs[j]; w[attIndex] *= cList[j]; } return glm; } } /** * Builds an L1-regularized regressor. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return an L1-regularized regressor. */ public GLM buildGaussianRegressor(Instances trainSet, int maxNumIters, double lambda) { return buildGaussianRegressor(trainSet, isSparse(trainSet), maxNumIters, lambda); } /** * Builds an L1-regularized regressor. Each row in the input matrix x represents a feature (instead of a data * point). 
Thus the input matrix is the transpose of the row-oriented data matrix. This procedure does not assume * the data is normalized or centered. * * @param attrs the attribute list. * @param x the inputs. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return an L1-regularized penalized regressor. */ public GLM buildGaussianRegressor(int[] attrs, double[][] x, double[] y, int maxNumIters, double lambda) { double[] w = new double[attrs.length]; double intercept = 0; // Initialize residuals double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Calculate sum of squares double[] sq = new double[x.length]; for (int j = 0; j < x.length; j++) { sq[j] = StatUtils.sumSq(x[j]); } // Coordinate descent final double tl1 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(x, sq, tl1, w, rTrain); double currLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } if (refit) { boolean[] selected = new boolean[attrs.length]; for (int i = 0; i < selected.length; i++) { selected[i] = w[i] != 0; } return refitGaussianRegressor(attrs, selected, x, y, maxNumIters); } else { return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.IDENTITY); } } /** * Builds an L1-regularized regressor on sparse inputs. Each row of the input represents a feature (instead of a * data point), i.e., in column-oriented format. This procedure does not assume the data is normalized or centered. * * @param attrs the attribute list. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. 
* @param lambda the lambda. * @return an L1-regularized regressor. */ public GLM buildGaussianRegressor(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters, double lambda) { double[] w = new double[attrs.length]; double intercept = 0; // Initialize residuals double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Calculate sum of squares double[] sq = new double[values.length]; for (int j = 0; j < values.length; j++) { sq[j] = StatUtils.sumSq(values[j]); } // Coordinate descent final double tl1 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(indices, values, sq, tl1, w, rTrain); double currLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } if (refit) { boolean[] selected = new boolean[attrs.length]; for (int i = 0; i < selected.length; i++) { selected[i] = w[i] != 0; } return refitGaussianRegressor(attrs, selected, indices, values, y, maxNumIters); } else { return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.IDENTITY); } } /** * Builds L1-regularized regressors for a sequence of regularization parameter lambdas. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @return L1-regularized regressors. 
*/
	public List<GLM> buildGaussianRegressors(Instances trainSet, boolean isSparse, int maxNumIters, int numLambdas,
			double minLambdaRatio) {
		final int[] attrs;
		final double[] cList;
		final List<GLM> glms;
		if (isSparse) {
			SparseDataset sd = getSparseDataset(trainSet, true);
			attrs = sd.attrs;
			cList = sd.cList;
			glms = buildGaussianRegressors(sd.attrs, sd.indices, sd.values, sd.y, maxNumIters, numLambdas,
					minLambdaRatio);
		} else {
			DenseDataset dd = getDenseDataset(trainSet, true);
			attrs = dd.attrs;
			cList = dd.cList;
			glms = buildGaussianRegressors(dd.attrs, dd.x, dd.y, maxNumIters, numLambdas, minLambdaRatio);
		}

		// Multiply every model's weights by cList, attribute by attribute.
		// NOTE(review): cList presumably undoes the scaling applied by the dataset conversion; confirm.
		for (GLM glm : glms) {
			double[] w = glm.w[0];
			for (int j = 0; j < cList.length; j++) {
				w[attrs[j]] *= cList[j];
			}
		}
		return glms;
	}

	/**
	 * Builds L1-regularized regressors for a sequence of regularization parameter lambdas.
	 *
	 * @param trainSet the training set.
	 * @param maxNumIters the maximum number of iterations.
	 * @param numLambdas the number of lambdas.
	 * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda.
	 * @return L1-regularized regressors.
	 */
	public List<GLM> buildGaussianRegressors(Instances trainSet, int maxNumIters, int numLambdas,
			double minLambdaRatio) {
		return buildGaussianRegressors(trainSet, isSparse(trainSet), maxNumIters, numLambdas, minLambdaRatio);
	}

	/**
	 * Builds L1-regularized regressors for a sequence of regularization parameter lambdas. Each row in the input matrix
	 * x represents a feature (instead of a data point). Thus the input matrix is the transpose of the row-oriented data
	 * matrix. This procedure does not assume the data is normalized or centered.
	 *
	 * @param attrs the attribute list.
	 * @param x the inputs.
	 * @param y the targets.
	 * @param maxNumIters the maximum number of iterations.
	 * @param numLambdas the number of lambdas.
	 * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda.
* @return L1-regularized regressors. */ public List buildGaussianRegressors(int[] attrs, double[][] x, double[] y, int maxNumIters, int numLambdas, double minLambdaRatio) { double[] w = new double[attrs.length]; double intercept = 0; // Backup targets double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Calculate sum of squares double[] sq = new double[x.length]; for (int i = 0; i < x.length; i++) { sq[i] = StatUtils.sumSq(x[i]); } // Determine max lambda double maxLambda = findMaxLambdaGaussian(x, rTrain); // Dampening factor for lambda double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas); // Compute the regularization path List glms = new ArrayList<>(numLambdas); Set structures = new HashSet<>(); double lambda = maxLambda; for (int g = 0; g < numLambdas; g++) { final double tl1 = lambda * y.length; // Coordinate descent for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(x, sq, tl1, w, rTrain); double currLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } lambda *= alpha; if (refit) { boolean[] selected = new boolean[attrs.length]; for (int i = 0; i < selected.length; i++) { selected[i] = w[i] != 0; } ModelStructure structure = new ModelStructure(selected); if (!structures.contains(structure)) { GLM glm = refitGaussianRegressor(attrs, selected, x, y, maxNumIters); glms.add(glm); structures.add(structure); } } else { GLM glm = GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.IDENTITY); glms.add(glm); } } return glms; } /** * Builds L1-regularized regressors for a sequence of regularization parameter lambdas on sparse format. 
Each row of * the input represents a feature (instead of a data point), i.e., in column-oriented format. This procedure does * not assume the data is normalized or centered. * * @param attrs the attribute list. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param numLambdas the number of lambdas. * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda. * @return L1-regularized regressors. */ public List buildGaussianRegressors(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters, int numLambdas, double minLambdaRatio) { double[] w = new double[attrs.length]; double intercept = 0; // Backup targets double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Calculate sum of squares double[] sq = new double[values.length]; for (int i = 0; i < values.length; i++) { sq[i] = StatUtils.sumSq(values[i]); } // Determine max lambda double maxLambda = findMaxLambdaGaussian(indices, values, y); // Dampening factor for lambda double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas); // Compute the regularization path List glms = new ArrayList<>(numLambdas); Set structures = new HashSet<>(); double lambda = maxLambda; for (int g = 0; g < numLambdas; g++) { final double tl1 = lambda * y.length; // Coordinate descent for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(indices, values, sq, tl1, w, rTrain); double currLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } lambda *= alpha; if (refit) { boolean[] selected = new boolean[attrs.length]; for (int i = 0; i < selected.length; i++) { selected[i] 
= w[i] != 0; } ModelStructure structure = new ModelStructure(selected); if (!structures.contains(structure)) { GLM glm = refitGaussianRegressor(attrs, selected, indices, values, y, maxNumIters); glms.add(glm); structures.add(structure); } } else { GLM glm = GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.IDENTITY); glms.add(glm); } } return glms; } protected void doOnePassGaussian(double[][] x, double[] sq, final double tl1, double[] w, double[] rTrain) { for (int j = 0; j < x.length; j++) { double[] v = x[j]; // Calculate weight updates using naive updates double wNew = w[j] * sq[j] + VectorUtils.dotProduct(v, rTrain); if (Math.abs(wNew) <= tl1) { wNew = 0; } else if (wNew > 0) { wNew -= tl1; } else { wNew += tl1; } wNew /= sq[j]; double delta = wNew - w[j]; w[j] = wNew; // Update residuals for (int i = 0; i < rTrain.length; i++) { rTrain[i] -= delta * v[i]; } } } protected void doOnePassGaussian(int[][] indices, double[][] values, double[] sq, final double tl1, double[] w, double[] rTrain) { for (int j = 0; j < indices.length; j++) { // Calculate weight updates using naive updates double wNew = w[j] * sq[j]; int[] index = indices[j]; double[] value = values[j]; for (int i = 0; i < index.length; i++) { wNew += rTrain[index[i]] * value[i]; } if (Math.abs(wNew) <= tl1) { wNew = 0; } else if (wNew > 0) { wNew -= tl1; } else { wNew += tl1; } wNew /= sq[j]; double delta = wNew - w[j]; w[j] = wNew; // Update residuals for (int i = 0; i < index.length; i++) { rTrain[index[i]] -= delta * value[i]; } } } protected void doOnePassBinomial(double[][] x, double[] theta, double[] y, final double tl1, double[] w, double[] pTrain, double[] rTrain) { for (int j = 0; j < x.length; j++) { if (Math.abs(theta[j]) <= MathUtils.EPSILON) { continue; } double[] v = x[j]; double eta = VectorUtils.dotProduct(rTrain, v); double newW = w[j] + eta / theta[j]; double t = tl1 / theta[j]; if (newW > t) { newW -= t; } else if (newW < -t) { newW += t; } else { newW = 0; } double delta = newW 
- w[j]; w[j] = newW; // Update predictions for (int i = 0; i < pTrain.length; i++) { pTrain[i] += delta * v[i]; rTrain[i] = OptimUtils.getPseudoResidual(pTrain[i], y[i]); } } } protected void doOnePassBinomial(int[][] indices, double[][] values, double[] theta, double[] y, final double tl1, double[] w, double[] pTrain, double[] rTrain) { for (int j = 0; j < indices.length; j++) { if (Math.abs(theta[j]) <= MathUtils.EPSILON) { continue; } double eta = 0; int[] index = indices[j]; double[] value = values[j]; for (int i = 0; i < index.length; i++) { int idx = index[i]; eta += rTrain[idx] * value[i]; } double newW = w[j] + eta / theta[j]; double t = tl1 / theta[j]; if (newW > t) { newW -= t; } else if (newW < -t) { newW += t; } else { newW = 0; } double delta = newW - w[j]; w[j] = newW; // Update predictions for (int i = 0; i < index.length; i++) { int idx = index[i]; pTrain[idx] += delta * value[i]; rTrain[idx] = OptimUtils.getPseudoResidual(pTrain[idx], y[idx]); } } } protected double findMaxLambdaGaussian(double[][] x, double[] y) { double mean = 0; if (fitIntercept) { mean = OptimUtils.fitIntercept(y); } // Determine max lambda double maxLambda = 0; for (double[] col : x) { double dot = Math.abs(VectorUtils.dotProduct(col, y)); if (dot > maxLambda) { maxLambda = dot; } } maxLambda /= y.length; if (fitIntercept) { VectorUtils.add(y, mean); } return maxLambda; } protected double findMaxLambdaGaussian(int[][] indices, double[][] values, double[] y) { double mean = 0; if (fitIntercept) { mean = OptimUtils.fitIntercept(y); } DenseVector v = new DenseVector(y); // Determine max lambda double maxLambda = 0; for (int i = 0; i < indices.length; i++) { int[] index = indices[i]; double[] value = values[i]; double dot = Math.abs(VectorUtils.dotProduct(new SparseVector(index, value), v)); if (dot > maxLambda) { maxLambda = dot; } } maxLambda /= y.length; if (fitIntercept) { VectorUtils.add(y, mean); } return maxLambda; } protected double findMaxLambdaBinomial(double[][] x, 
double[] y, double[] pTrain, double[] rTrain) { if (fitIntercept) { OptimUtils.fitIntercept(pTrain, rTrain, y); } double maxLambda = 0; for (double[] col : x) { double eta = 0; for (int i = 0; i < col.length; i++) { double r = OptimUtils.getPseudoResidual(pTrain[i], y[i]); r *= col[i]; eta += r; } double t = Math.abs(eta); if (t > maxLambda) { maxLambda = t; } } maxLambda /= y.length; if (fitIntercept) { Arrays.fill(pTrain, 0); OptimUtils.computePseudoResidual(pTrain, y, rTrain); } return maxLambda; } protected double findMaxLambdaBinomial(int[][] indices, double[][] values, double[] y, double[] pTrain, double[] rTrain) { if (fitIntercept) { OptimUtils.fitIntercept(pTrain, rTrain, y); } double maxLambda = 0; for (int k = 0; k < values.length; k++) { double eta = 0; int[] index = indices[k]; double[] value = values[k]; for (int i = 0; i < index.length; i++) { int idx = index[i]; double r = OptimUtils.getPseudoResidual(pTrain[idx], y[idx]); r *= value[i]; eta += r; } double t = Math.abs(eta); if (t > maxLambda) { maxLambda = t; } } maxLambda /= y.length; if (fitIntercept) { Arrays.fill(pTrain, 0); OptimUtils.computePseudoResidual(pTrain, y, rTrain); } return maxLambda; } /** * Returns the lambda. * * @return the lambda. */ public double getLambda() { return lambda; } /** * Returns the number of lambdas. * * @return the number of lambdas. */ public int getNumLambdas() { return numLambdas; } /** * Returns the task of this learner. * * @return the task of this learner. */ public Task getTask() { return task; } /** * Returns {@code true} if we refit the model. * * @return {@code true} if we refit the model. */ public boolean refit() { return refit; } /** * Sets whether we refit the model. * * @param refit {@code true} if we refit the model. 
*/ public void refit(boolean refit) { this.refit = refit; } protected GLM refitGaussianRegressor(int[] attrs, boolean[] selected, double[][] x, double[] y, int maxNumIters) { List xList = new ArrayList<>(); for (int j = 0; j < attrs.length; j++) { if (selected[j]) { xList.add(x[j]); } } if (xList.size() == 0) { GLM glm = new GLM(0); if (fitIntercept) { glm.intercept[0] = StatUtils.mean(y); } return glm; } double[][] xNew = new double[xList.size()][]; for (int i = 0; i < xNew.length; i++) { xNew[i] = xList.get(i); } int[] attrsNew = new int[xNew.length]; for (int i = 0; i < attrsNew.length; i++) { attrsNew[i] = i; } RidgeLearner learner = new RidgeLearner(); learner.setVerbose(verbose); learner.setEpsilon(epsilon); learner.fitIntercept(fitIntercept); // A ridge regression with very small regularization parameter // This often improves stability a lot GLM glm = learner.buildGaussianRegressor(attrsNew, xNew, y, maxNumIters, 1e-8); double[] w = new double[attrs.length]; double[] coef = glm.coefficients(0); int k = 0; for (int j = 0; j < w.length; j++) { if (selected[j]) { w[j] = coef[k++]; } } return GLMOptimUtils.getGLM(attrs, w, glm.intercept(0), LinkFunction.IDENTITY); } protected GLM refitGaussianRegressor(int[] attrs, boolean[] selected, int[][] indices, double[][] values, double[] y, int maxNumIters) { List indicesList = new ArrayList<>(); List valuesList = new ArrayList<>(); for (int j = 0; j < attrs.length; j++) { if (selected[j]) { indicesList.add(indices[j]); valuesList.add(values[j]); } } if (indicesList.size() == 0) { GLM glm = new GLM(0); if (fitIntercept) { glm.intercept[0] = StatUtils.mean(y); } return glm; } int[][] indicesNew = new int[indicesList.size()][]; for (int i = 0; i < indicesNew.length; i++) { indicesNew[i] = indicesList.get(i); } double[][] valuesNew = new double[valuesList.size()][]; for (int i = 0; i < indicesNew.length; i++) { valuesNew[i] = valuesList.get(i); } int[] attrsNew = new int[indicesNew.length]; for (int i = 0; i < 
attrsNew.length; i++) { attrsNew[i] = i; } RidgeLearner learner = new RidgeLearner(); learner.setVerbose(verbose); learner.setEpsilon(epsilon); learner.fitIntercept(fitIntercept); // A ridge regression with very small regularization parameter // This often improves stability a lot GLM glm = learner.buildGaussianRegressor(attrsNew, indicesNew, valuesNew, y, maxNumIters, 1e-8); double[] w = new double[attrs.length]; double[] coef = glm.coefficients(0); int k = 0; for (int j = 0; j < w.length; j++) { if (selected[j]) { w[j] = coef[k++]; } } return GLMOptimUtils.getGLM(attrs, w, glm.intercept(0), LinkFunction.IDENTITY); } protected GLM refitClassifier(int[] attrs, boolean[] selected, double[][] x, double[] y, int maxNumIters) { List xList = new ArrayList<>(); for (int j = 0; j < attrs.length; j++) { if (selected[j]) { xList.add(x[j]); } } double[][] xNew = new double[xList.size()][]; for (int i = 0; i < xNew.length; i++) { xNew[i] = xList.get(i); } int[] attrsNew = new int[xNew.length]; for (int i = 0; i < attrsNew.length; i++) { attrsNew[i] = i; } RidgeLearner learner = new RidgeLearner(); learner.setVerbose(verbose); learner.setEpsilon(epsilon); learner.fitIntercept(fitIntercept); // A ridge regression with very small regularization parameter // This often improves stability a lot GLM glm = learner.buildBinaryClassifier(attrsNew, xNew, y, maxNumIters, 1e-8); double[] w = new double[attrs.length]; double[] coef = glm.coefficients(0); int k = 0; for (int j = 0; j < w.length; j++) { if (selected[j]) { w[j] = coef[k++]; } } return GLMOptimUtils.getGLM(attrs, w, glm.intercept(0), LinkFunction.LOGIT); } protected GLM refitClassifier(int[] attrs, boolean[] selected, int[][] indices, double[][] values, double[] y, int maxNumIters) { List indicesList = new ArrayList<>(); List valuesList = new ArrayList<>(); for (int j = 0; j < attrs.length; j++) { if (selected[j]) { indicesList.add(indices[j]); valuesList.add(values[j]); } } int[][] indicesNew = new int[indicesList.size()][]; 
for (int i = 0; i < indicesNew.length; i++) { indicesNew[i] = indicesList.get(i); } double[][] valuesNew = new double[valuesList.size()][]; for (int i = 0; i < indicesNew.length; i++) { valuesNew[i] = valuesList.get(i); } int[] attrsNew = new int[indicesNew.length]; for (int i = 0; i < attrsNew.length; i++) { attrsNew[i] = i; } RidgeLearner learner = new RidgeLearner(); learner.setVerbose(verbose); learner.setEpsilon(epsilon); learner.fitIntercept(fitIntercept); // A ridge regression with very small regularization parameter // This often improves stability a lot GLM glm = learner.buildBinaryClassifier(attrsNew, indicesNew, valuesNew, y, maxNumIters, 1e-8); double[] w = new double[attrs.length]; double[] coef = glm.coefficients(0); int k = 0; for (int j = 0; j < w.length; j++) { if (selected[j]) { w[j] = coef[k++]; } } return GLMOptimUtils.getGLM(attrs, w, glm.intercept(0), LinkFunction.LOGIT); } /** * Sets the lambda. * * @param lambda the lambda. */ public void setLambda(double lambda) { this.lambda = lambda; } /** * Sets the number of lambdas. * * @param numLambdas the number of lambdas. */ public void setNumLambdas(int numLambdas) { this.numLambdas = numLambdas; } /** * Sets the task of this learner. * * @param task the task of this learner. 
*/ public void setTask(Task task) { this.task = task; } } ================================================ FILE: src/main/java/mltk/predictor/glm/RidgeLearner.java ================================================ package mltk.predictor.glm; import java.util.Arrays; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.cmdline.options.LearnerWithTaskOptions; import mltk.core.Attribute; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.core.io.InstancesReader; import mltk.predictor.Family; import mltk.predictor.LinkFunction; import mltk.predictor.io.PredictorWriter; import mltk.util.MathUtils; import mltk.util.OptimUtils; import mltk.util.StatUtils; import mltk.util.VectorUtils; /** * Class for learning L2-regularized linear model via coordinate descent. * * @author Yin Lou * */ public class RidgeLearner extends GLMLearner { static class Options extends LearnerWithTaskOptions { @Argument(name = "-m", description = "maximum num of iterations (default: 0)") int maxIter = 0; @Argument(name = "-l", description = "lambda (default: 0)") double lambda = 0; } /** * Trains L2-regularized GLMs. * *
	 * Usage: mltk.predictor.glm.RidgeLearner
	 * -t	train set path
	 * [-g]	task between classification (c) and regression (r) (default: r)
	 * [-r]	attribute file path
	 * [-o]	output model path
	 * [-V]	verbose (default: true)
	 * [-m]	maximum num of iterations (default: 0)
	 * [-l]	lambda (default: 0)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(RidgeLearner.class, opts); Task task = null; try { parser.parse(args); task = Task.get(opts.task); } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath); RidgeLearner learner = new RidgeLearner(); learner.setVerbose(opts.verbose); learner.setTask(task); learner.setLambda(opts.lambda); learner.setMaxNumIters(opts.maxIter); long start = System.currentTimeMillis(); GLM glm = learner.build(trainSet); long end = System.currentTimeMillis(); System.out.println("Time: " + (end - start) / 1000.0); if (opts.outputModelPath != null) { PredictorWriter.write(glm, opts.outputModelPath); } } protected double lambda; protected Task task; /** * Constructor. */ public RidgeLearner() { lambda = 0; // no regularization task = Task.REGRESSION; } @Override public GLM build(Instances instances) { GLM glm = null; if (maxNumIters < 0) { maxNumIters = 20; } switch (task) { case REGRESSION: glm = buildGaussianRegressor(instances, maxNumIters, lambda); break; case CLASSIFICATION: glm = buildClassifier(instances, maxNumIters, lambda); break; default: break; } return glm; } @Override public GLM build(Instances trainSet, Family family) { GLM glm = null; if (maxNumIters < 0) { maxNumIters = 20; } switch (family) { case GAUSSIAN: glm = buildGaussianRegressor(trainSet, maxNumIters, lambda); break; case BINOMIAL: glm = buildClassifier(trainSet, maxNumIters, lambda); break; default: throw new IllegalArgumentException("Unsupported family: " + family); } return glm; } /** * Builds an L2-regularized binary classifier. Each row in the input matrix x represents a feature (instead of a * data point). Thus the input matrix is the transpose of the row-oriented data matrix. 
This procedure does not * assume the data is normalized or centered. * * @param attrs the attribute list. * @param x the inputs. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return an L2-regularized binary classifier. */ public GLM buildBinaryClassifier(int[] attrs, double[][] x, double[] y, int maxNumIters, double lambda) { double[] w = new double[attrs.length]; double intercept = 0; double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); // Calculate theta's double[] theta = new double[x.length]; for (int i = 0; i < x.length; i++) { theta[i] = StatUtils.sumSq(x[i]) / 4; } // Coordinate gradient descent final double tl2 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeRidgeLoss(pTrain, y, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(x, theta, y, tl2, w, pTrain, rTrain); double currLoss = GLMOptimUtils.computeRidgeLoss(pTrain, y, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.LOGIT); } /** * Builds an L2-regularized binary classifier on sparse inputs. Each row of the input represents a feature (instead * of a data point), i.e., in column-oriented format. This procedure does not assume the data is normalized or * centered. * * @param attrs the attribute list. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return an L2-regularized classifier. 
*/ public GLM buildBinaryClassifier(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters, double lambda) { double[] w = new double[attrs.length]; double intercept = 0; double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); // Calculate theta's double[] theta = new double[values.length]; for (int i = 0; i < values.length; i++) { theta[i] = StatUtils.sumSq(values[i]) / 4; } // Coordinate gradient descent final double tl2 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeRidgeLoss(pTrain, y, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(indices, values, theta, y, tl2, w, pTrain, rTrain); double currLoss = GLMOptimUtils.computeRidgeLoss(pTrain, y, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.LOGIT); } /** * Builds L2-regularized binary classifiers for a sequence of regularization parameter lambdas. Each row in the * input matrix x represents a feature (instead of a data point). Thus the input matrix is the transpose of the * row-oriented data matrix. This procedure does not assume the data is normalized or centered. * * @param attrs the attribute list. * @param x the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambdas the lambdas array. * @return L2-regularized classifiers. 
*/ public GLM[] buildBinaryClassifiers(int[] attrs, double[][] x, double[] y, int maxNumIters, double[] lambdas) { double[] w = new double[attrs.length]; double intercept = 0; double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); // Calculate theta's double[] theta = new double[x.length]; for (int i = 0; i < x.length; i++) { theta[i] = StatUtils.sumSq(x[i]) / 4; } GLM[] glms = new GLM[lambdas.length]; Arrays.sort(lambdas); for (int g = 0; g < glms.length; g++) { double lambda = lambdas[g]; // Coordinate gradient descent final double tl2 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeRidgeLoss(pTrain, y, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(x, theta, y, tl2, w, pTrain, rTrain); double currLoss = GLMOptimUtils.computeRidgeLoss(pTrain, y, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } glms[g] = GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.LOGIT); } return glms; } /** * Builds L2-regularized binary classifiers for a sequence of regularization parameter lambdas on sparse format. * Each row of the input represents a feature (instead of a data point), i.e., in column-oriented format. This * procedure does not assume the data is normalized or centered. * * @param attrs the attribute list. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambdas the lambdas array. * @return L2-regularized classifiers. 
*/ public GLM[] buildBinaryClassifiers(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters, double[] lambdas) { double[] w = new double[attrs.length]; double intercept = 0; double[] pTrain = new double[y.length]; double[] rTrain = new double[y.length]; OptimUtils.computePseudoResidual(pTrain, y, rTrain); // Calculate theta's double[] theta = new double[values.length]; for (int i = 0; i < values.length; i++) { theta[i] = StatUtils.sumSq(values[i]) / 4; } GLM[] glms = new GLM[lambdas.length]; Arrays.sort(lambdas); for (int g = 0; g < glms.length; g++) { double lambda = lambdas[g]; // Coordinate gradient descent final double tl2 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeRidgeLoss(pTrain, y, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(pTrain, rTrain, y); } doOnePassBinomial(indices, values, theta, y, tl2, w, pTrain, rTrain); double currLoss = GLMOptimUtils.computeRidgeLoss(pTrain, y, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } glms[g] = GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.LOGIT); } return glms; } /** * Builds an L2-regularized binary classifier. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return an L2-regularized binary classifier. 
*/ public GLM buildClassifier(Instances trainSet, boolean isSparse, int maxNumIters, double lambda) { Attribute classAttribute = trainSet.getTargetAttribute(); if (classAttribute.getType() != Attribute.Type.NOMINAL) { throw new IllegalArgumentException("Class attribute must be nominal."); } NominalAttribute clazz = (NominalAttribute) classAttribute; int numClasses = clazz.getStates().length; if (isSparse) { SparseDataset sd = getSparseDataset(trainSet, true); int[] attrs = sd.attrs; int[][] indices = sd.indices; double[][] values = sd.values; double[] y = new double[sd.y.length]; double[] cList = sd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == 0 ? 1 : 0; } GLM glm = buildBinaryClassifier(attrs, indices, values, y, maxNumIters, lambda); double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; w[attIndex] *= cList[j]; } return glm; } else { int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1); GLM glm = new GLM(numClasses, p); for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == k ? 1 : 0; } GLM binaryClassifier = buildBinaryClassifier(attrs, indices, values, y, maxNumIters, lambda); double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } return glm; } } else { DenseDataset dd = getDenseDataset(trainSet, true); int[] attrs = dd.attrs; double[][] x = dd.x; double[] y = new double[dd.y.length]; double[] cList = dd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == 0 ? 
1 : 0; } GLM glm = buildBinaryClassifier(attrs, x, y, maxNumIters, lambda); double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; w[attIndex] *= cList[j]; } return glm; } else { int p = attrs.length == 0 ? 0 : attrs[attrs.length - 1] + 1; GLM glm = new GLM(numClasses, p); for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == k ? 1 : 0; } GLM binaryClassifier = buildBinaryClassifier(attrs, x, y, maxNumIters, lambda); double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } return glm; } } } /** * Builds an L2-regularized classifier. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return an L2-regularized classifier. */ public GLM buildClassifier(Instances trainSet, int maxNumIters, double lambda) { return buildClassifier(trainSet, isSparse(trainSet), maxNumIters, lambda); } /** * Builds L2-regularized classifiers for a sequence of regularization parameter lambdas. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param maxNumIters the maximum number of iterations. * @param lambdas the lambdas array. * @return L2-regularized binary classifiers. 
*/ public GLM[] buildClassifiers(Instances trainSet, boolean isSparse, int maxNumIters, double[] lambdas) { Attribute classAttribute = trainSet.getTargetAttribute(); if (classAttribute.getType() != Attribute.Type.NOMINAL) { throw new IllegalArgumentException("Class attribute must be nominal."); } NominalAttribute clazz = (NominalAttribute) classAttribute; int numClasses = clazz.getStates().length; if (isSparse) { SparseDataset sd = getSparseDataset(trainSet, true); int[] attrs = sd.attrs; int[][] indices = sd.indices; double[][] values = sd.values; double[] y = new double[sd.y.length]; double[] cList = sd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == 0 ? 1 : 0; } GLM[] glms = buildBinaryClassifiers(attrs, indices, values, y, maxNumIters, lambdas); for (GLM glm : glms) { double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; w[attIndex] *= cList[j]; } } return glms; } else { int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1); GLM[] glms = new GLM[lambdas.length]; for (int i = 0; i < glms.length; i++) { glms[i] = new GLM(numClasses, p); } for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) sd.y[i]; y[i] = label == k ? 
1 : 0; } GLM[] binaryClassifiers = buildBinaryClassifiers(attrs, indices, values, y, maxNumIters, lambdas); for (int l = 0; l < glms.length; l++) { GLM binaryClassifier = binaryClassifiers[l]; GLM glm = glms[l]; double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } } return glms; } } else { DenseDataset dd = getDenseDataset(trainSet, true); int[] attrs = dd.attrs; double[][] x = dd.x; double[] y = new double[dd.y.length]; double[] cList = dd.cList; if (numClasses == 2) { for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == 0 ? 1 : 0; } GLM[] glms = buildBinaryClassifiers(attrs, x, y, maxNumIters, lambdas); for (GLM glm : glms) { double[] w = glm.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; w[attIndex] *= cList[j]; } } return glms; } else { int p = attrs.length == 0 ? 0 : attrs[attrs.length - 1] + 1; GLM[] glms = new GLM[lambdas.length]; for (int i = 0; i < glms.length; i++) { glms[i] = new GLM(numClasses, p); } for (int k = 0; k < numClasses; k++) { // One-vs-the-rest for (int i = 0; i < y.length; i++) { int label = (int) dd.y[i]; y[i] = label == k ? 1 : 0; } GLM[] binaryClassifiers = buildBinaryClassifiers(attrs, x, y, maxNumIters, lambdas); for (int l = 0; l < glms.length; l++) { GLM binaryClassifier = binaryClassifiers[l]; GLM glm = glms[l]; double[] w = binaryClassifier.w[0]; for (int j = 0; j < cList.length; j++) { int attIndex = attrs[j]; glm.w[k][attIndex] = w[attIndex] * cList[j]; } glm.intercept[k] = binaryClassifier.intercept[0]; } } return glms; } } } /** * Builds L2-regularized classifiers for a sequence of regularization parameter lambdas. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @param lambdas the lambdas array. * @return L2-regularized binary classifiers. 
*/
    public GLM[] buildClassifiers(Instances trainSet, int maxNumIters, double[] lambdas) {
        final boolean sparse = isSparse(trainSet);
        return buildClassifiers(trainSet, sparse, maxNumIters, lambdas);
    }

    /**
     * Builds an L2-regularized regressor from a training set, fitting on the sparse or dense representation as
     * requested and then multiplying every learned weight by its {@code cList} entry.
     *
     * @param trainSet the training set.
     * @param isSparse {@code true} if the training set is treated as sparse.
     * @param maxNumIters the maximum number of iterations.
     * @param lambda the lambda.
     * @return an L2-regularized regressor.
     */
    public GLM buildGaussianRegressor(Instances trainSet, boolean isSparse, int maxNumIters, double lambda) {
        final int[] attrs;
        final double[] cList;
        final GLM glm;
        if (isSparse) {
            SparseDataset sd = getSparseDataset(trainSet, true);
            attrs = sd.attrs;
            cList = sd.cList;
            glm = buildGaussianRegressor(sd.attrs, sd.indices, sd.values, sd.y, maxNumIters, lambda);
        } else {
            DenseDataset dd = getDenseDataset(trainSet, true);
            attrs = dd.attrs;
            cList = dd.cList;
            glm = buildGaussianRegressor(dd.attrs, dd.x, dd.y, maxNumIters, lambda);
        }

        // Rescale the weights in place, addressed by attribute index
        double[] weights = glm.w[0];
        for (int j = 0; j < cList.length; j++) {
            weights[attrs[j]] *= cList[j];
        }
        return glm;
    }

    /**
     * Builds an L2-regularized regressor, choosing the sparse or dense code path from {@code isSparse(trainSet)}.
     *
     * @param trainSet the training set.
     * @param maxNumIters the maximum number of iterations.
     * @param lambda the lambda.
     * @return an L2-regularized regressor.
     */
    public GLM buildGaussianRegressor(Instances trainSet, int maxNumIters, double lambda) {
        final boolean sparse = isSparse(trainSet);
        return buildGaussianRegressor(trainSet, sparse, maxNumIters, lambda);
    }

    /**
     * Builds an L2-regularized regressor. Each row in the input matrix x represents a feature (instead of a data
     * point). Thus the input matrix is the transpose of the row-oriented data matrix. This procedure does not assume
     * the data is normalized or centered.
     *
     * @param attrs the attribute list.
     * @param x the inputs.
     * @param y the targets.
     * @param maxNumIters the maximum number of iterations.
     * @param lambda the lambda.
* @return an L2-regularized regressor. */ public GLM buildGaussianRegressor(int[] attrs, double[][] x, double[] y, int maxNumIters, double lambda) { double[] w = new double[attrs.length]; double intercept = 0; // Initialize residuals double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Calculate sum of squares double[] sq = new double[x.length]; for (int i = 0; i < x.length; i++) { sq[i] = StatUtils.sumSq(x[i]); } // Coordinate descent final double tl2 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeRidgeLoss(rTrain, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(x, sq, tl2, w, rTrain); double currLoss = GLMOptimUtils.computeRidgeLoss(rTrain, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.IDENTITY); } /** * Builds an L2-regularized regressor on sparse inputs. Each row of the input represents a feature (instead of a * data point), i.e., in column-oriented format. This procedure does not assume the data is normalized or centered. * * @param attrs the attribute list. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambda the lambda. * @return an L2-regularized regressor. 
*/ public GLM buildGaussianRegressor(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters, double lambda) { double[] w = new double[attrs.length]; double intercept = 0; // Initialize residuals double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Calculate sum of squares double[] sq = new double[attrs.length]; for (int i = 0; i < values.length; i++) { sq[i] = StatUtils.sumSq(values[i]); } // Coordinate descent final double tl2 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeRidgeLoss(rTrain, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(indices, values, sq, tl2, w, rTrain); double currLoss = GLMOptimUtils.computeRidgeLoss(rTrain, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.IDENTITY); } /** * Builds L2-regularized regressors for a sequence of regularization parameter lambdas. * * @param trainSet the training set. * @param isSparse {@code true} if the training set is treated as sparse. * @param maxNumIters the maximum number of iterations. * @param lambdas the lambdas array. * @return L2-regularized regressors. 
*/
    public GLM[] buildGaussianRegressors(Instances trainSet, boolean isSparse, int maxNumIters, double[] lambdas) {
        final int[] attrs;
        final double[] cList;
        final GLM[] glms;
        if (isSparse) {
            SparseDataset sd = getSparseDataset(trainSet, true);
            attrs = sd.attrs;
            cList = sd.cList;
            glms = buildGaussianRegressors(sd.attrs, sd.indices, sd.values, sd.y, maxNumIters, lambdas);
        } else {
            DenseDataset dd = getDenseDataset(trainSet, true);
            attrs = dd.attrs;
            cList = dd.cList;
            glms = buildGaussianRegressors(dd.attrs, dd.x, dd.y, maxNumIters, lambdas);
        }

        // Rescale every model's weights in place, addressed by attribute index
        for (GLM glm : glms) {
            double[] weights = glm.w[0];
            for (int j = 0; j < cList.length; j++) {
                weights[attrs[j]] *= cList[j];
            }
        }
        return glms;
    }

    /**
     * Builds L2-regularized regressors for a sequence of regularization parameter lambdas, choosing the sparse or
     * dense code path from {@code isSparse(trainSet)}.
     *
     * @param trainSet the training set.
     * @param maxNumIters the maximum number of iterations.
     * @param lambdas the lambdas array.
     * @return L2-regularized regressors.
     */
    public GLM[] buildGaussianRegressors(Instances trainSet, int maxNumIters, double[] lambdas) {
        final boolean sparse = isSparse(trainSet);
        return buildGaussianRegressors(trainSet, sparse, maxNumIters, lambdas);
    }

    /**
     * Builds L2-regularized regressors for a sequence of regularization parameter lambdas on dense inputs. Each row in
     * the input matrix x represents a feature (instead of a data point). Thus the input matrix is the transpose of the
     * row-oriented data matrix. This procedure does not assume the data is normalized or centered.
     *
     * @param attrs the attribute list.
     * @param x the inputs.
     * @param y the targets.
     * @param maxNumIters the maximum number of iterations.
     * @param lambdas the lambdas array; it is sorted in ascending order in place.
     * @return L2-regularized regressors, one per lambda in ascending lambda order.
*/ public GLM[] buildGaussianRegressors(int[] attrs, double[][] x, double[] y, int maxNumIters, double[] lambdas) { double[] w = new double[attrs.length]; double intercept = 0; GLM[] glms = new GLM[lambdas.length]; Arrays.sort(lambdas); // Backup targets double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Calculate sum of squares double[] sq = new double[x.length]; for (int i = 0; i < x.length; i++) { sq[i] = StatUtils.sumSq(x[i]); } // Compute the regularization path for (int g = 0; g < glms.length; g++) { double lambda = lambdas[g]; // Coordinate descent final double tl2 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeRidgeLoss(rTrain, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(x, sq, tl2, w, rTrain); double currLoss = GLMOptimUtils.computeRidgeLoss(rTrain, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } glms[g] = GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.IDENTITY); } return glms; } /** * Builds L2-regularized regressors for a sequence of regularization parameter lambdas on sparse inputs. Each row of * the input represents a feature (instead of a data point), i.e., in column-oriented format. This procedure does * not assume the data is normalized or centered. * * @param attrs the attribute list. * @param indices the indices. * @param values the values. * @param y the targets. * @param maxNumIters the maximum number of iterations. * @param lambdas the lambdas array. * @return L2-regularized regressors. 
*/ public GLM[] buildGaussianRegressors(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters, double[] lambdas) { double[] w = new double[attrs.length]; double intercept = 0; GLM[] glms = new GLM[lambdas.length]; Arrays.sort(lambdas); // Backup targets double[] rTrain = new double[y.length]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = y[i]; } // Calculate sum of squares double[] sq = new double[attrs.length]; for (int i = 0; i < values.length; i++) { sq[i] = StatUtils.sumSq(values[i]); } // Compute the regularization path for (int g = 0; g < glms.length; g++) { double lambda = lambdas[g]; // Coordinate descent final double tl2 = lambda * y.length; for (int iter = 0; iter < maxNumIters; iter++) { double prevLoss = GLMOptimUtils.computeRidgeLoss(rTrain, w, lambda); if (fitIntercept) { intercept += OptimUtils.fitIntercept(rTrain); } doOnePassGaussian(indices, values, sq, tl2, w, rTrain); double currLoss = GLMOptimUtils.computeRidgeLoss(rTrain, w, lambda); if (verbose) { System.out.println("Iteration " + iter + ": " + " " + currLoss); } if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) { break; } } glms[g] = GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.IDENTITY); } return glms; } protected void doOnePassGaussian(double[][] x, double[] sq, final double tl2, double[] w, double[] rTrain) { for (int j = 0; j < x.length; j++) { double[] v = x[j]; // Calculate weight updates using naive updates double eta = VectorUtils.dotProduct(rTrain, v); double wNew = (w[j] * sq[j] + eta) / (sq[j] + tl2); double delta = wNew - w[j]; w[j] = wNew; // Update residuals for (int i = 0; i < rTrain.length; i++) { rTrain[i] -= delta * v[i]; } } } protected void doOnePassGaussian(int[][] indices, double[][] values, double[] sq, final double tl2, double[] w, double[] rTrain) { for (int j = 0; j < indices.length; j++) { // Calculate weight updates using naive updates double wNew = w[j] * sq[j]; int[] index = indices[j]; double[] value = values[j]; 
for (int i = 0; i < index.length; i++) { wNew += rTrain[index[i]] * value[i]; } wNew /= (sq[j] + tl2); double delta = wNew - w[j]; w[j] = wNew; // Update residuals for (int i = 0; i < index.length; i++) { rTrain[index[i]] -= delta * value[i]; } } } protected void doOnePassBinomial(double[][] x, double[] theta, double[] y, final double tl2, double[] w, double[] pTrain, double[] rTrain) { for (int j = 0; j < x.length; j++) { if (Math.abs(theta[j]) <= MathUtils.EPSILON) { continue; } double[] v = x[j]; double eta = VectorUtils.dotProduct(rTrain, v); double wNew = (w[j] * theta[j] + eta) / (theta[j] + tl2); double delta = wNew - w[j]; w[j] = wNew; // Update predictions for (int i = 0; i < pTrain.length; i++) { pTrain[i] += delta * v[i]; rTrain[i] = OptimUtils.getPseudoResidual(pTrain[i], y[i]); } } } protected void doOnePassBinomial(int[][] indices, double[][] values, double[] theta, double[] y, final double tl2, double[] w, double[] pTrain, double[] rTrain) { for (int j = 0; j < indices.length; j++) { if (Math.abs(theta[j]) <= MathUtils.EPSILON) { continue; } double eta = 0; int[] index = indices[j]; double[] value = values[j]; for (int i = 0; i < index.length; i++) { int idx = index[i]; eta += rTrain[idx] * value[i]; } double wNew = (w[j] * theta[j] + eta) / (theta[j] + tl2); double delta = wNew - w[j]; w[j] = wNew; // Update predictions for (int i = 0; i < index.length; i++) { int idx = index[i]; pTrain[idx] += delta * value[i]; rTrain[idx] = OptimUtils.getPseudoResidual(pTrain[idx], y[idx]); } } } /** * Returns the lambda. * * @return the lambda. */ public double getLambda() { return lambda; } /** * Returns the task of this learner. * * @return the task of this learner. */ public Task getTask() { return task; } /** * Sets the lambda. * * @param lambda the lambda. */ public void setLambda(double lambda) { this.lambda = lambda; } /** * Sets the task of this learner. * * @param task the task of this learner. 
*/ public void setTask(Task task) { this.task = task; } } ================================================ FILE: src/main/java/mltk/predictor/glm/package-info.java ================================================ /** * Provides algorithms for fitting generalized linear models (GLMs). */ package mltk.predictor.glm; ================================================ FILE: src/main/java/mltk/predictor/io/PredictorReader.java ================================================ package mltk.predictor.io; import java.io.BufferedReader; import java.io.FileReader; import mltk.predictor.Predictor; /** * Class for reading predictors. * * @author Yin Lou * */ public class PredictorReader { /** * Reads a predictor. The caller is responsible for converting the predictor to correct type. * * @param path the file path for the predictor. * @return the parsed predictor. * @throws Exception */ public static Predictor read(String path) throws Exception { BufferedReader in = new BufferedReader(new FileReader(path)); String line = in.readLine(); String predictorName = line.substring(1, line.length() - 1).split(": ")[1]; Class clazz = Class.forName(predictorName); Predictor predictor = (Predictor) clazz.getDeclaredConstructor().newInstance(); predictor.read(in); in.close(); return predictor; } /** * Reads a predictor. The caller is responsible for providing the correct predictor type. * * @param path the file path for the predictor. * @param clazz the class of the predictor. * @param the type of the predictor class. * @return the parsed predictor. * @throws Exception */ public static T read(String path, Class clazz) throws Exception { Predictor predictor = read(path); return clazz.cast(predictor); } /** * Reads a predictor from an input reader. The caller is responsible for converting the predictor to correct type. * * @param in the input reader. * @return the parsed predictor. 
* @throws Exception */ public static Predictor read(BufferedReader in) throws Exception { String line = in.readLine(); String predictorName = line.substring(1, line.length() - 1).split(": ")[1]; Class clazz = Class.forName(predictorName); Predictor predictor = (Predictor) clazz.getDeclaredConstructor().newInstance(); predictor.read(in); return predictor; } /** * Reads a predictor from an input reader. The caller is responsible for providing the correct predictor type. * * @param in the input reader. * @param clazz the class of the predictor. * @param the type of the predictor class. * @return the parsed predictor. * @throws Exception */ public static T read(BufferedReader in, Class clazz) throws Exception { Predictor predictor = read(in); return clazz.cast(predictor); } } ================================================ FILE: src/main/java/mltk/predictor/io/PredictorWriter.java ================================================ package mltk.predictor.io; import java.io.PrintWriter; import mltk.predictor.Predictor; /** * Class for writing predictors. * * @author Yin Lou * */ public class PredictorWriter { /** * Writes a predictor to file. * * @param predictor the predictor to write. * @param path the file path. * @throws Exception */ public static void write(Predictor predictor, String path) throws Exception { PrintWriter out = new PrintWriter(path); predictor.write(out); out.flush(); out.close(); } } ================================================ FILE: src/main/java/mltk/predictor/io/package-info.java ================================================ /** * Provides classes for reading and writing predictors. */ package mltk.predictor.io; ================================================ FILE: src/main/java/mltk/predictor/package-info.java ================================================ /** * Provides interfaces and classes for predictors. 
*/ package mltk.predictor; ================================================ FILE: src/main/java/mltk/predictor/tree/DecisionTable.java ================================================ package mltk.predictor.tree; import java.io.BufferedReader; import java.io.PrintWriter; import java.util.Arrays; import mltk.core.Instance; import mltk.util.ArrayUtils; import mltk.util.VectorUtils; /** * Class for decision tables. * * @author Yin Lou * */ public class DecisionTable implements RTree { protected int[] attIndices; protected double[] splits; protected long[] predIndices; protected double[] predValues; /** * Constructor. */ public DecisionTable() { } /** * Constructor. * * @param attIndices the attribute indices. * @param splits the splits. * @param predIndices the prediction indices. * @param predValues the prediction values. */ public DecisionTable(int[] attIndices, double[] splits, long[] predIndices, double[] predValues) { this.attIndices = attIndices; this.splits = splits; this.predIndices = predIndices; this.predValues = predValues; } /** * Returns the attribute indices in this tree. * * @return the attribute indices in this tree. */ public int[] getAttributeIndices() { return attIndices; } /** * Returns the splits in this tree. * * @return the splits in this tree. 
*/ public double[] getSplits() { return splits; } @Override public void multiply(double c) { VectorUtils.multiply(predValues, c); } @Override public void read(BufferedReader in) throws Exception { in.readLine(); attIndices = ArrayUtils.parseIntArray(in.readLine()); in.readLine(); splits = ArrayUtils.parseDoubleArray(in.readLine()); in.readLine(); predIndices = ArrayUtils.parseLongArray(in.readLine()); in.readLine(); predValues = ArrayUtils.parseDoubleArray(in.readLine()); } @Override public void write(PrintWriter out) throws Exception { out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName()); out.println("Attributes: " + attIndices.length); out.println(Arrays.toString(attIndices)); out.println("Splits: " + splits.length); out.println(Arrays.toString(splits)); out.println("Prediction Indices: " + predIndices.length); out.println(Arrays.toString(predIndices)); out.println("Prediction Values: " + predValues.length); out.println(Arrays.toString(predValues)); } @Override public DecisionTable copy() { int[] attIndicesCopy = Arrays.copyOf(attIndices, attIndices.length); double[] splitsCopy = Arrays.copyOf(splits, splits.length); long[] predIndicesCopy = Arrays.copyOf(predIndices, predIndices.length); double[] predValuesCopy = Arrays.copyOf(predValues, predValues.length); return new DecisionTable(attIndicesCopy, splitsCopy, predIndicesCopy, predValuesCopy); } @Override public double regress(Instance instance) { long predIdx = 0L; for (int j = 0; j < attIndices.length; j++) { int attIndex = attIndices[j]; double split = splits[j]; if (instance.getValue(attIndex) <= split) { predIdx = (predIdx << 1) | 1L; } else { predIdx <<= 1; } } return regress(predIdx); } /** * Returns the prediction based on prediction index. * * @param predIdx * @return the prediction based on prediction index. 
*/ public double regress(long predIdx) { int idx = Arrays.binarySearch(predIndices, predIdx); if (idx < 0) { return 0; } else { return predValues[idx]; } } } ================================================ FILE: src/main/java/mltk/predictor/tree/DecisionTableLearner.java ================================================ package mltk.predictor.tree; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import mltk.core.Attribute; import mltk.core.Attribute.Type; import mltk.core.BinnedAttribute; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.util.ArrayUtils; import mltk.util.OptimUtils; import mltk.util.Random; import mltk.util.StatUtils; import mltk.util.tuple.DoublePair; import mltk.util.tuple.IntDoublePair; import mltk.util.tuple.LongDoublePair; import mltk.util.tuple.LongDoublePairComparator; /** * Class for learning decision tables. * *

 * <p>
 * Reference:<br>
 * Y. Lou and M. Obukhov. BDT: Boosting Decision Tables for High Accuracy and Scoring Efficiency. In Proceedings of the
 * 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), Halifax, Nova Scotia, Canada, 2017.
 * </p>

* * This class has a different implementation to better fit the design of this package. * * @author Yin Lou * */ public class DecisionTableLearner extends RTreeLearner { /** * Enumeration of construction mode. * * @author Yin Lou * */ public enum Mode { ONE_PASS_GREEDY, MULTI_PASS_CYCLIC, MULTI_PASS_RANDOM; } protected Mode mode; protected int maxDepth; protected int numPasses; /** * Constructor. */ public DecisionTableLearner() { mode = Mode.ONE_PASS_GREEDY; maxDepth = 6; numPasses = 2; } @Override public void setParameters(String mode) { String[] data = mode.split(":"); if (data.length != 2) { throw new IllegalArgumentException(); } this.setMaxDepth(Integer.parseInt(data[1])); switch (data[0]) { case "g": this.setConstructionMode(Mode.ONE_PASS_GREEDY); break; case "c": this.setConstructionMode(Mode.MULTI_PASS_CYCLIC); this.setNumPasses(2); break; case "r": this.setConstructionMode(Mode.MULTI_PASS_RANDOM); this.setNumPasses(2); break; default: throw new IllegalArgumentException(); } } @Override public boolean isRobust() { return false; } /** * Returns the construction mode. * * @return the construction mode. */ public Mode getConstructionMode() { return mode; } /** * Sets the construction mode. * * @param mode the construction mode. */ public void setConstructionMode(Mode mode) { this.mode = mode; } /** * Returns the maximum depth. * * @return the maximum depth. */ public int getMaxDepth() { return maxDepth; } /** * Sets the maximum depth. * * @param maxDepth the maximum depth. */ public void setMaxDepth(int maxDepth) { this.maxDepth = maxDepth; } /** * Returns the number of passes. This parameter is used in multi-pass cyclic mode. * * @return the number of passes. */ public int getNumPasses() { return numPasses; } /** * Sets the number of passes. * * @param numPasses the number of passes. 
*/ public void setNumPasses(int numPasses) { this.numPasses = numPasses; } @Override public DecisionTable build(Instances instances) { DecisionTable ot = null; switch (mode) { case ONE_PASS_GREEDY: ot = buildOnePassGreedy(instances, maxDepth); break; case MULTI_PASS_CYCLIC: ot = buildMultiPassCyclic(instances, maxDepth, numPasses); break; case MULTI_PASS_RANDOM: ot = buildMultiPassRandom(instances, maxDepth, numPasses); default: break; } return ot; } /** * Builds a standard oblivious regression tree using greedy tree induction. * * @param instances the training set. * @param maxDepth the maximum depth. * @return an oblivious regression tree. */ public DecisionTable buildOnePassGreedy(Instances instances, int maxDepth) { // stats[0]: totalWeights // stats[1]: sum // stats[2]: weightedMean double[] stats = new double[3]; Map map = new HashMap<>(instances.size()); List attList = new ArrayList<>(maxDepth); List splitList = new ArrayList<>(maxDepth); Dataset dataset = null; if (this.cache != null) { dataset = Dataset.create(this.cache, instances); } else { dataset = Dataset.create(instances); } map.put(Long.valueOf(0L), dataset); if (maxDepth <= 0) { getStats(dataset.instances, stats); final double weightedMean = stats[2]; return new DecisionTable( new int[] {}, new double[] {}, new long[] { 0L }, new double[] { weightedMean }); } List attributes = instances.getAttributes(); List> featureValues = new ArrayList<>(attributes.size()); for (int j = 0; j < attributes.size(); j++) { Attribute attribute = attributes.get(j); List values = new ArrayList<>(); if (attribute.getType() == Type.BINNED) { int numBins = ((BinnedAttribute) attribute).getNumBins(); for (int i = 0; i < numBins; i++) { values.add((double) i); } } else if (attribute.getType() == Type.NOMINAL) { int cardinality = ((NominalAttribute) attribute).getCardinality(); for (int i = 0; i < cardinality; i++) { values.add((double) i); } } else { Set set = new HashSet<>(); for (Instance instance : instances) { 
set.add(instance.getValue(attribute)); } values.addAll(set); Collections.sort(values); } featureValues.add(values); } for (int d = 0; d < maxDepth; d++) { double bestGain = Double.NEGATIVE_INFINITY; List splitCandidates = new ArrayList<>(); for (int j = 0; j < attributes.size(); j++) { List values = featureValues.get(j); if (values.size() <= 1) { continue; } Attribute attribute = attributes.get(j); int attIndex = attribute.getIndex(); String attName = attribute.getName(); double[] gains = new double[values.size() - 1]; for (Dataset data : map.values()) { getStats(data.instances, stats); final double totalWeights = stats[0]; final double sum = stats[1]; List sortedList = data.sortedLists.get(attName); List uniqueValues = new ArrayList<>(sortedList.size()); List histogram = new ArrayList<>(sortedList.size()); getHistogram(data.instances, sortedList, uniqueValues, totalWeights, sum, histogram); double[] localGains = evalSplits(uniqueValues, histogram, totalWeights, sum); processGains(uniqueValues, localGains, values, gains); } int idx = StatUtils.indexOfMax(gains); if (bestGain <= gains[idx]) { double split = (values.get(idx) + values.get(idx + 1)) / 2; if (bestGain < gains[idx]) { bestGain = gains[idx]; splitCandidates.clear(); } splitCandidates.add(new IntDoublePair(attIndex, split)); } } if (splitCandidates.size() == 0) { break; } Random rand = Random.getInstance(); IntDoublePair split = splitCandidates.get(rand.nextInt(splitCandidates.size())); attList.add(split.v1); splitList.add(split.v2); Map mapNew = new HashMap<>(map.size() * 2); for (Map.Entry entry : map.entrySet()) { Long key = entry.getKey(); Dataset data = entry.getValue(); Dataset left = new Dataset(data.instances); Dataset right = new Dataset(data.instances); data.split(split.v1, split.v2, left, right); if (left.instances.size() > 0) { Long leftKey = (key << 1) | 1L; mapNew.put(leftKey, left); } if (right.instances.size() > 0) { Long rightKey = key << 1; mapNew.put(rightKey, right); } } map = mapNew; } 
int[] attIndices = ArrayUtils.toIntArray(attList); double[] splits = ArrayUtils.toDoubleArray(splitList); List list = new ArrayList<>(splits.length); for (Map.Entry entry : map.entrySet()) { Long key = entry.getKey(); Dataset data = entry.getValue(); getStats(data.instances, stats); list.add(new LongDoublePair(key, stats[2])); } Collections.sort(list, new LongDoublePairComparator()); long[] predIndices = new long[list.size()]; double[] predValues = new double[list.size()]; for (int i = 0; i < predIndices.length; i++) { LongDoublePair pair = list.get(i); predIndices[i] = pair.v1; predValues[i] = pair.v2; } return new DecisionTable(attIndices, splits, predIndices, predValues); } /** * Builds an oblivious regression tree using multi-pass cyclic backfitting. * * @param instances the training set. * @param maxDepth the maximum depth. * @param numPasses the number of passes. * @return an oblivious regression tree. */ public DecisionTable buildMultiPassCyclic(Instances instances, int maxDepth, int numPasses) { // stats[0]: totalWeights // stats[1]: sum // stats[2]: weightedMean double[] stats = new double[3]; Map map = new HashMap<>(instances.size()); int[] attIndices = new int[maxDepth]; double[] splits = new double[maxDepth]; Dataset dataset = null; if (this.cache != null) { dataset = Dataset.create(this.cache, instances); } else { dataset = Dataset.create(instances); } map.put(Long.valueOf(0L), dataset); if (maxDepth <= 0) { getStats(dataset.instances, stats); final double weightedMean = stats[2]; return new DecisionTable( new int[] {}, new double[] {}, new long[] { 0L }, new double[] { weightedMean }); } List attributes = instances.getAttributes(); List> featureValues = new ArrayList<>(attributes.size()); for (int j = 0; j < attributes.size(); j++) { Attribute attribute = attributes.get(j); List values = new ArrayList<>(); if (attribute.getType() == Type.BINNED) { int numBins = ((BinnedAttribute) attribute).getNumBins(); for (int i = 0; i < numBins; i++) { 
values.add((double) i); } } else if (attribute.getType() == Type.NOMINAL) { int cardinality = ((NominalAttribute) attribute).getCardinality(); for (int i = 0; i < cardinality; i++) { values.add((double) i); } } else { Set set = new HashSet<>(); for (Instance instance : instances) { set.add(instance.getValue(attribute)); } values.addAll(set); Collections.sort(values); } featureValues.add(values); } for (int pass = 0; pass < numPasses; pass++) { for (int d = 0; d < maxDepth; d++) { double bestGain = Double.NEGATIVE_INFINITY; List splitCandidates = new ArrayList<>(); // Remove depth d Set processedKeys = new HashSet<>(); Map mapNew = new HashMap<>(map.size()); for (Map.Entry entry : map.entrySet()) { Long key = entry.getKey(); if (processedKeys.contains(key)) { continue; } Dataset data = entry.getValue(); int s = maxDepth - d - 1; Long otherKey = key ^ (1L << s); if (map.containsKey(otherKey)) { long check = (key >> s) & 1; Dataset left = null; Dataset right = null; if (check > 0) { left = data; right = map.get(otherKey); } else { left = map.get(otherKey); right = data; } // This key will be updated anyway mapNew.put(key, Dataset.merge(left, right)); processedKeys.add(key); processedKeys.add(otherKey); } else { mapNew.put(key, data); } } map = mapNew; for (int j = 0; j < attributes.size(); j++) { Attribute attribute = attributes.get(j); int attIndex = attribute.getIndex(); String attName = attribute.getName(); List values = featureValues.get(j); if (values.size() <= 1) { continue; } double[] gains = new double[values.size() - 1]; for (Dataset data : map.values()) { getStats(data.instances, stats); final double totalWeights = stats[0]; final double sum = stats[1]; List sortedList = data.sortedLists.get(attName); List uniqueValues = new ArrayList<>(sortedList.size()); List histogram = new ArrayList<>(sortedList.size()); getHistogram(data.instances, sortedList, uniqueValues, totalWeights, sum, histogram); double[] localGains = evalSplits(uniqueValues, histogram, 
totalWeights, sum); processGains(uniqueValues, localGains, values, gains); } int idx = StatUtils.indexOfMax(gains); if (bestGain <= gains[idx]) { double split = (values.get(idx) + values.get(idx + 1)) / 2; if (bestGain < gains[idx]) { bestGain = gains[idx]; splitCandidates.clear(); } splitCandidates.add(new IntDoublePair(attIndex, split)); } } if (splitCandidates.size() == 0) { break; } Random rand = Random.getInstance(); IntDoublePair split = splitCandidates.get(rand.nextInt(splitCandidates.size())); attIndices[d] = split.v1; splits[d] = split.v2; mapNew = new HashMap<>(map.size() * 2); for (Map.Entry entry : map.entrySet()) { Long key = entry.getKey(); Dataset data = entry.getValue(); Dataset left = new Dataset(data.instances); Dataset right = new Dataset(data.instances); data.split(split.v1, split.v2, left, right); int s = maxDepth - d - 1; if (left.instances.size() > 0) { Long leftKey = key | (1L << s); mapNew.put(leftKey, left); } if (right.instances.size() > 0) { Long rightKey = key & ~(1L << s); mapNew.put(rightKey, right); } } map = mapNew; } } List list = new ArrayList<>(splits.length); for (Map.Entry entry : map.entrySet()) { Long key = entry.getKey(); Dataset data = entry.getValue(); getStats(data.instances, stats); list.add(new LongDoublePair(key, stats[2])); } Collections.sort(list, new LongDoublePairComparator()); long[] predIndices = new long[list.size()]; double[] predValues = new double[list.size()]; for (int i = 0; i < predIndices.length; i++) { LongDoublePair pair = list.get(i); predIndices[i] = pair.v1; predValues[i] = pair.v2; } return new DecisionTable(attIndices, splits, predIndices, predValues); } /** * Builds an oblivious regression tree using multi-pass random backfitting. * * @param instances the training set. * @param maxDepth the maximum depth. * @param numPasses the number of passes. * @return an oblivious regression tree. 
*/ public DecisionTable buildMultiPassRandom(Instances instances, int maxDepth, int numPasses) { // stats[0]: totalWeights // stats[1]: sum // stats[2]: weightedMean double[] stats = new double[3]; Map map = new HashMap<>(instances.size()); int[] attIndices = new int[maxDepth]; double[] splits = new double[maxDepth]; Dataset dataset = null; if (this.cache != null) { dataset = Dataset.create(this.cache, instances); } else { dataset = Dataset.create(instances); } map.put(Long.valueOf(0L), dataset); if (maxDepth <= 0) { getStats(dataset.instances, stats); final double weightedMean = stats[2]; return new DecisionTable( new int[] {}, new double[] {}, new long[] { 0L }, new double[] { weightedMean }); } List attributes = instances.getAttributes(); List> featureValues = new ArrayList<>(attributes.size()); for (int j = 0; j < attributes.size(); j++) { Attribute attribute = attributes.get(j); List values = new ArrayList<>(); if (attribute.getType() == Type.BINNED) { int numBins = ((BinnedAttribute) attribute).getNumBins(); for (int i = 0; i < numBins; i++) { values.add((double) i); } } else if (attribute.getType() == Type.NOMINAL) { int cardinality = ((NominalAttribute) attribute).getCardinality(); for (int i = 0; i < cardinality; i++) { values.add((double) i); } } else { Set set = new HashSet<>(); for (Instance instance : instances) { set.add(instance.getValue(attribute)); } values.addAll(set); Collections.sort(values); } featureValues.add(values); } for (int iter = 0; iter < numPasses; iter++) { for (int k = 0; k < maxDepth; k++) { double bestGain = Double.NEGATIVE_INFINITY; List splitCandidates = new ArrayList<>(); int d = k; if (iter > 0) { d = Random.getInstance().nextInt(maxDepth); } // Remove depth d Set processedKeys = new HashSet<>(); Map mapNew = new HashMap<>(map.size()); for (Map.Entry entry : map.entrySet()) { Long key = entry.getKey(); if (processedKeys.contains(key)) { continue; } Dataset data = entry.getValue(); int s = maxDepth - d - 1; Long otherKey = key 
^ (1L << s); if (map.containsKey(otherKey)) { long check = (key >> s) & 1; Dataset left = null; Dataset right = null; if (check > 0) { left = data; right = map.get(otherKey); } else { left = map.get(otherKey); right = data; } // This key will be updated anyway mapNew.put(key, Dataset.merge(left, right)); processedKeys.add(key); processedKeys.add(otherKey); } else { mapNew.put(key, data); } } map = mapNew; for (int j = 0; j < attributes.size(); j++) { Attribute attribute = attributes.get(j); int attIndex = attribute.getIndex(); String attName = attribute.getName(); List values = featureValues.get(j); if (values.size() <= 1) { continue; } double[] gains = new double[values.size() - 1]; for (Dataset data : map.values()) { getStats(data.instances, stats); final double totalWeights = stats[0]; final double sum = stats[1]; List sortedList = data.sortedLists.get(attName); List uniqueValues = new ArrayList<>(sortedList.size()); List histogram = new ArrayList<>(sortedList.size()); getHistogram(data.instances, sortedList, uniqueValues, totalWeights, sum, histogram); double[] localGains = evalSplits(uniqueValues, histogram, totalWeights, sum); processGains(uniqueValues, localGains, values, gains); } int idx = StatUtils.indexOfMax(gains); if (bestGain <= gains[idx]) { double split = (values.get(idx) + values.get(idx + 1)) / 2; if (bestGain < gains[idx]) { bestGain = gains[idx]; splitCandidates.clear(); } splitCandidates.add(new IntDoublePair(attIndex, split)); } } if (splitCandidates.size() == 0) { break; } Random rand = Random.getInstance(); IntDoublePair split = splitCandidates.get(rand.nextInt(splitCandidates.size())); attIndices[d] = split.v1; splits[d] = split.v2; mapNew = new HashMap<>(map.size() * 2); for (Map.Entry entry : map.entrySet()) { Long key = entry.getKey(); Dataset data = entry.getValue(); Dataset left = new Dataset(data.instances); Dataset right = new Dataset(data.instances); data.split(split.v1, split.v2, left, right); int s = maxDepth - d - 1; if 
(left.instances.size() > 0) { Long leftKey = key | (1L << s); mapNew.put(leftKey, left); } if (right.instances.size() > 0) { Long rightKey = key & ~(1L << s); mapNew.put(rightKey, right); } } map = mapNew; } } List list = new ArrayList<>(splits.length); for (Map.Entry entry : map.entrySet()) { Long key = entry.getKey(); Dataset data = entry.getValue(); getStats(data.instances, stats); list.add(new LongDoublePair(key, stats[2])); } Collections.sort(list, new LongDoublePairComparator()); long[] predIndices = new long[list.size()]; double[] predValues = new double[list.size()]; for (int i = 0; i < predIndices.length; i++) { LongDoublePair pair = list.get(i); predIndices[i] = pair.v1; predValues[i] = pair.v2; } return new DecisionTable(attIndices, splits, predIndices, predValues); } protected void processGains(List uniqueValues, double[] localGains, List values, double[] gains) { int i = 0; int j = 0; double noSplitGain = localGains[localGains.length - 1]; double minV = uniqueValues.get(0); while (j < gains.length) { double v2 = values.get(j); if (v2 < minV) { gains[j] += noSplitGain; j++; } else { break; } } double prevGain = localGains[i]; while (i < localGains.length && j < gains.length) { double v1 = uniqueValues.get(i); double v2 = values.get(j); if (v1 == v2) { gains[j] += localGains[i]; prevGain = localGains[i]; i++; j++; } while (v1 > v2) { gains[j] += prevGain; j++; v2 = values.get(j); } } while (j < gains.length) { gains[j] += noSplitGain; j++; } } protected double[] evalSplits(List uniqueValues, List hist, double totalWeights, double sum) { double weight1 = hist.get(0).v1; double weight2 = totalWeights - weight1; double sum1 = hist.get(0).v2; double sum2 = sum - sum1; double[] gains = new double[uniqueValues.size()]; gains[0] = OptimUtils.getGain(sum1, weight1) + OptimUtils.getGain(sum2, weight2); for (int i = 1; i < uniqueValues.size() - 1; i++) { final double w = hist.get(i).v1; final double s = hist.get(i).v2; weight1 += w; weight2 -= w; sum1 += s; sum2 
-= s; gains[i] = OptimUtils.getGain(sum1, weight1) + OptimUtils.getGain(sum2, weight2); } // gain for no split gains[uniqueValues.size() - 1] = OptimUtils.getGain(sum, totalWeights); return gains; } } ================================================ FILE: src/main/java/mltk/predictor/tree/RTree.java ================================================ package mltk.predictor.tree; import mltk.predictor.Regressor; /** * Interface for regression trees. * * @author Yin Lou * */ public interface RTree extends Regressor { /** * Multiplies this tree with a constant. * * @param c the constant. */ public void multiply(double c); /** * Returns a deep copy of this tree. */ public RTree copy(); } ================================================ FILE: src/main/java/mltk/predictor/tree/RTreeLearner.java ================================================ package mltk.predictor.tree; import java.util.Collections; import java.util.List; import mltk.core.Instance; import mltk.core.Instances; import mltk.util.tuple.DoublePair; import mltk.util.tuple.IntDoublePair; /** * Abstract class for learning regression trees. 
*
 * @author Yin Lou
 *
 */
public abstract class RTreeLearner extends TreeLearner {

	@Override
	public abstract RTree build(Instances instances);

	/**
	 * Builds the histogram of one attribute: the sorted distinct values and, for each of them, the total weight and
	 * the weighted sum of targets of the instances carrying that value.
	 *
	 * @param instances the instances.
	 * @param pairs (instance index, attribute value) pairs; the grouping below relies on equal values being adjacent.
	 *            When it has fewer entries than {@code instances}, the missing instances are treated as having
	 *            value zero.
	 * @param uniqueValues output list receiving the distinct values.
	 * @param w the total weight of {@code instances}.
	 * @param s the weighted sum of targets of {@code instances}.
	 * @param histogram output list receiving one (weight, weighted sum) pair per unique value.
	 */
	protected void getHistogram(Instances instances, List<IntDoublePair> pairs, List<Double> uniqueValues, double w,
			double s, List<DoublePair> histogram) {
		if (pairs.size() > 0) {
			IntDoublePair first = pairs.get(0);
			double lastValue = first.v2;
			double totalWeight = instances.get(first.v1).getWeight();
			double sum = instances.get(first.v1).getTarget() * totalWeight;

			for (int i = 1; i < pairs.size(); i++) {
				IntDoublePair pair = pairs.get(i);
				Instance instance = instances.get(pair.v1);
				double value = pair.v2;
				double weight = instance.getWeight();
				double resp = instance.getTarget();
				if (value != lastValue) {
					// A new value starts: flush the bucket of the previous one
					uniqueValues.add(lastValue);
					histogram.add(new DoublePair(totalWeight, sum));
					lastValue = value;
					totalWeight = weight;
					sum = resp * weight;
				} else {
					totalWeight += weight;
					sum += resp * weight;
				}
			}
			uniqueValues.add(lastValue);
			histogram.add(new DoublePair(totalWeight, sum));
		}

		if (pairs.size() != instances.size()) {
			// Zero entries are present
			double sumWeight = 0;
			double sumTarget = 0;
			for (DoublePair pair : histogram) {
				sumWeight += pair.v1;
				sumTarget += pair.v2;
			}
			// Whatever is not accounted for by the explicit entries belongs to the implicit zeros
			double weightOnZero = w - sumWeight;
			double sumOnZero = s - sumTarget;
			int idx = Collections.binarySearch(uniqueValues, ZERO);
			if (idx < 0) {
				// This should always happen
				int pos = -idx - 1;
				uniqueValues.add(pos, ZERO);
				histogram.add(pos, new DoublePair(weightOnZero, sumOnZero));
			} else {
				// Zero also occurs explicitly: fold the implicit zeros into that bucket instead of dropping them
				DoublePair explicit = histogram.get(idx);
				histogram.set(idx, new DoublePair(explicit.v1 + weightOnZero, explicit.v2 + sumOnZero));
			}
		}
	}

	/**
	 * Computes weighted statistics of the targets.
	 *
	 * @param instances the instances.
	 * @param stats output array of length at least 3: {@code stats[0]} is the total weight, {@code stats[1]} the
	 *            weighted sum of targets and {@code stats[2]} the weighted mean (0 when it would be NaN).
	 * @return {@code true} if all targets are identical or there are no instances.
	 */
	protected boolean getStats(Instances instances, double[] stats) {
		stats[0] = stats[1] = stats[2] = 0;
		if (instances.size() == 0) {
			return true;
		}
		double firstTarget = instances.get(0).getTarget();
		boolean stdIs0 = true;
		for (Instance instance : instances) {
			double weight = instance.getWeight();
			double target = instance.getTarget();
			stats[0] += weight;
			stats[1] += weight * target;
			if (stdIs0 && target != firstTarget) {
				stdIs0 = false;
			}
		}
		stats[2] = stats[1] / stats[0];
		if (Double.isNaN(stats[2])) {
			stats[2] = 0;
		}
		return stdIs0;
	}

}
================================================ FILE: src/main/java/mltk/predictor/tree/RegressionTree.java ================================================ package mltk.predictor.tree; import java.io.BufferedReader; import java.io.PrintWriter; import mltk.core.Instance; /** * Class for regression trees. * * @author Yin Lou * */ public class RegressionTree implements RTree { /** * The root of a tree. */ protected TreeNode root; /** * Constructs an empty tree. */ public RegressionTree() { root = null; } /** * Constructs a regression tree with specified root. * * @param root the root. */ public RegressionTree(TreeNode root) { this.root = root; } /** * Returns the root of this regression tree. * * @return the root of this regression tree. */ public TreeNode getRoot() { return root; } /** * Sets the root for this regression tree. * * @param root the new root. */ public void setRoot(TreeNode root) { this.root = root; } /** * Returns the leaf node. * * @param instance the data point. * @return the leaf node. */ public RegressionTreeLeaf getLeafNode(Instance instance) { TreeNode node = root; while (!node.isLeaf()) { TreeInteriorNode interiorNode = (TreeInteriorNode) node; if (interiorNode.goLeft(instance)) { node = interiorNode.getLeftChild(); } else { node = interiorNode.getRightChild(); } } return (RegressionTreeLeaf) node; } /** * Multiplies this regression tree with a constant. * * @param c the constant. */ public void multiply(double c) { multiply(root, c); } /** * Multiplies this subtree with a constant. * * @param node the root of the subtree. * @param c the constant. 
*/ protected void multiply(TreeNode node, double c) { if (node.isLeaf()) { RegressionTreeLeaf leaf = (RegressionTreeLeaf) node; leaf.prediction *= c; } else { TreeInteriorNode interiorNode = (TreeInteriorNode) node; multiply(interiorNode.left, c); multiply(interiorNode.right, c); } } @Override public double regress(Instance instance) { return getLeafNode(instance).getPrediction(); } @Override public void read(BufferedReader in) throws Exception { in.readLine(); Class clazz = Class.forName(in.readLine()); root = (TreeNode) clazz.getDeclaredConstructor().newInstance(); root.read(in); } @Override public void write(PrintWriter out) throws Exception { out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName()); out.println(); root.write(out); } @Override public RegressionTree copy() { return new RegressionTree(root.copy()); } } ================================================ FILE: src/main/java/mltk/predictor/tree/RegressionTreeLeaf.java ================================================ package mltk.predictor.tree; import java.io.BufferedReader; import java.io.PrintWriter; /** * Class for regression tree leaves. * * @author Yin Lou * */ public class RegressionTreeLeaf extends TreeNode { protected double prediction; /** * Constructor. */ public RegressionTreeLeaf() { } /** * Constructs a leaf node with a constant prediction. * * @param prediction the prediction for this leaf node. */ public RegressionTreeLeaf(double prediction) { this.prediction = prediction; } @Override public boolean isLeaf() { return true; } /** * Sets the prediction for this leaf node. * * @param prediction the prediction for this leaf node. */ public void setPrediction(double prediction) { this.prediction = prediction; } /** * Returns the prediction for this leaf node. * * @return the prediction for this leaf node. 
*/ public double getPrediction() { return prediction; } @Override public void read(BufferedReader in) throws Exception { prediction = Double.parseDouble(in.readLine().split(": ")[1]); } @Override public void write(PrintWriter out) throws Exception { out.println(this.getClass().getCanonicalName()); out.println("Prediction: " + prediction); } @Override public RegressionTreeLeaf copy() { return new RegressionTreeLeaf(prediction); } } ================================================ FILE: src/main/java/mltk/predictor/tree/RegressionTreeLearner.java ================================================ package mltk.predictor.tree; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.PriorityQueue; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.cmdline.options.LearnerOptions; import mltk.core.Attribute; import mltk.core.Instances; import mltk.core.io.InstancesReader; import mltk.predictor.evaluation.Evaluator; import mltk.predictor.io.PredictorWriter; import mltk.util.Random; import mltk.util.Stack; import mltk.util.Element; import mltk.util.OptimUtils; import mltk.util.tuple.DoublePair; import mltk.util.tuple.IntDoublePair; /** * Class for learning regression trees. * * @author Yin Lou * */ public class RegressionTreeLearner extends RTreeLearner { static class Options extends LearnerOptions { @Argument(name = "-m", description = "construction mode:parameter. Construction mode can be alpha limited (a), depth limited (d), number of leaves limited (l) and minimum leaf size limited (s) (default: a:0.001)") String mode = "a:0.001"; @Argument(name = "-s", description = "seed of the random number generator (default: 0)") long seed = 0L; } /** * Trains a regression tree. * *
	 * Usage: mltk.predictor.tree.RegressionTreeLearner
	 * -t	train set path
	 * [-r]	attribute file path
	 * [-o]	output model path
	 * [-V]	verbose (default: true)
	 * [-m]	construction mode:parameter. Construction mode can be alpha limited (a), depth limited (d), number of leaves limited (l) and minimum leaf size limited (s) (default: a:0.001)
	 * [-s]	seed of the random number generator (default: 0)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(RegressionTreeLearner.class, opts); RegressionTreeLearner learner = new RegressionTreeLearner(); try { parser.parse(args); learner.setParameters(opts.mode); } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Random.getInstance().setSeed(opts.seed); Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath); long start = System.currentTimeMillis(); RegressionTree rt = learner.build(trainSet); long end = System.currentTimeMillis(); System.out.println("Time: " + (end - start) / 1000.0 + " (s)."); System.out.println(Evaluator.evalRMSE(rt, trainSet)); if (opts.outputModelPath != null) { PredictorWriter.write(rt, opts.outputModelPath); } } /** * Enumeration of construction mode. * * @author Yin Lou * */ public enum Mode { DEPTH_LIMITED, NUM_LEAVES_LIMITED, ALPHA_LIMITED, MIN_LEAF_SIZE_LIMITED; } protected int maxDepth; protected int maxNumLeaves; protected int minLeafSize; protected double alpha; protected Mode mode; /** * Constructor. 
*/ public RegressionTreeLearner() { alpha = 0.01; mode = Mode.ALPHA_LIMITED; } @Override public RegressionTree build(Instances instances) { RegressionTree rt = null; switch (mode) { case ALPHA_LIMITED: rt = buildAlphaLimitedTree(instances, alpha); break; case NUM_LEAVES_LIMITED: rt = buildNumLeafLimitedTree(instances, maxNumLeaves); break; case DEPTH_LIMITED: rt = buildDepthLimitedTree(instances, maxDepth); break; case MIN_LEAF_SIZE_LIMITED: rt = buildMinLeafSizeLimitedTree(instances, minLeafSize); default: break; } return rt; } @Override public void setParameters(String mode) { String[] data = mode.split(":"); if (data.length != 2) { throw new IllegalArgumentException(); } switch (data[0]) { case "a": this.setConstructionMode(Mode.ALPHA_LIMITED); this.setAlpha(Double.parseDouble(data[1])); break; case "d": this.setConstructionMode(Mode.DEPTH_LIMITED); this.setMaxDepth(Integer.parseInt(data[1])); break; case "l": this.setConstructionMode(Mode.NUM_LEAVES_LIMITED); this.setMaxNumLeaves(Integer.parseInt(data[1])); break; case "s": this.setConstructionMode(Mode.MIN_LEAF_SIZE_LIMITED); this.setMinLeafSize(Integer.parseInt(data[1])); break; default: throw new IllegalArgumentException(); } } @Override public boolean isRobust() { return false; } /** * Returns the alpha. * * @return the alpha. */ public double getAlpha() { return alpha; } /** * Returns the construction mode. * * @return the construction mode. */ public Mode getConstructionMode() { return mode; } /** * Returns the maximum depth. * * @return the maximum depth. */ public int getMaxDepth() { return maxDepth; } /** * Returns the maximum number of leaves. * * @return the maximum number of leaves. */ public int getMaxNumLeaves() { return maxNumLeaves; } /** * Returns the minimum leaf size. * * @return the minimum leaf size. */ public int getMinLeafSize() { return minLeafSize; } /** * Sets the alpha. Alpha is the maximum proportion of the training set in the leaf node. * * @param alpha the alpha. 
*/ public void setAlpha(double alpha) { this.alpha = alpha; } /** * Sets the construction mode. * * @param mode the construction mode. */ public void setConstructionMode(Mode mode) { this.mode = mode; } /** * Sets the maximum depth. * * @param maxDepth the maximum depth. */ public void setMaxDepth(int maxDepth) { this.maxDepth = maxDepth; } /** * Sets the maximum number of leaves. * * @param maxNumLeaves the maximum number of leaves. */ public void setMaxNumLeaves(int maxNumLeaves) { this.maxNumLeaves = maxNumLeaves; } /** * Sets the minimum leaf size. * * @param minLeafSize the minimum leaf size. */ public void setMinLeafSize(int minLeafSize) { this.minLeafSize = minLeafSize; } protected RegressionTree buildAlphaLimitedTree(Instances instances, double alpha) { final int limit = (int) (alpha * instances.size()); return buildMinLeafSizeLimitedTree(instances, limit); } protected RegressionTree buildDepthLimitedTree(Instances instances, int maxDepth) { RegressionTree tree = new RegressionTree(); final int limit = 5; // stats[0]: totalWeights // stats[1]: sum // stats[2]: weightedMean // stats[3]: splitEval double[] stats = new double[4]; if (maxDepth <= 0) { getStats(instances, stats); tree.root = new RegressionTreeLeaf(stats[1]); return tree; } Map datasets = new HashMap<>(); Map depths = new HashMap<>(); Dataset dataset = null; if (this.cache != null) { dataset = Dataset.create(this.cache, instances); } else { dataset = Dataset.create(instances); } tree.root = createNode(dataset, limit, stats); PriorityQueue> q = new PriorityQueue<>(); q.add(new Element(tree.root, stats[2])); datasets.put(tree.root, dataset); depths.put(tree.root, 0); while (!q.isEmpty()) { Element elemt = q.remove(); TreeNode node = elemt.element; Dataset data = datasets.get(node); int depth = depths.get(node); if (!node.isLeaf()) { TreeInteriorNode interiorNode = (TreeInteriorNode) node; Dataset left = new Dataset(data.instances); Dataset right = new Dataset(data.instances); split(data, 
interiorNode, left, right); if (depth >= maxDepth - 1) { getStats(left.instances, stats); interiorNode.left = new RegressionTreeLeaf(stats[2]); getStats(right.instances, stats); interiorNode.right = new RegressionTreeLeaf(stats[2]); } else { interiorNode.left = createNode(left, limit, stats); if (!interiorNode.left.isLeaf()) { q.add(new Element(interiorNode.left, stats[3])); datasets.put(interiorNode.left, left); depths.put(interiorNode.left, depth + 1); } interiorNode.right = createNode(right, limit, stats); if (!interiorNode.right.isLeaf()) { q.add(new Element(interiorNode.right, stats[3])); datasets.put(interiorNode.right, right); depths.put(interiorNode.right, depth + 1); } } } } return tree; } protected RegressionTree buildMinLeafSizeLimitedTree(Instances instances, int limit) { RegressionTree tree = new RegressionTree(); // stats[0]: totalWeights // stats[1]: sum // stats[2]: weightedMean // stats[3]: splitEval double[] stats = new double[4]; Dataset dataset = null; if (this.cache != null) { dataset = Dataset.create(this.cache, instances); } else { dataset = Dataset.create(instances); } Stack nodes = new Stack<>(); Stack datasets = new Stack<>(); tree.root = createNode(dataset, limit, stats); nodes.push(tree.root); datasets.push(dataset); while (!nodes.isEmpty()) { TreeNode node = nodes.pop(); Dataset data = datasets.pop(); if (!node.isLeaf()) { TreeInteriorNode interiorNode = (TreeInteriorNode) node; Dataset left = new Dataset(data.instances); Dataset right = new Dataset(data.instances); split(data, interiorNode, left, right); interiorNode.left = createNode(left, limit, stats); interiorNode.right = createNode(right, limit, stats); nodes.push(interiorNode.left); datasets.push(left); nodes.push(interiorNode.right); datasets.push(right); } } return tree; } protected RegressionTree buildNumLeafLimitedTree(Instances instances, int maxNumLeaves) { RegressionTree tree = new RegressionTree(); final int limit = 5; // stats[0]: totalWeights // stats[1]: sum // 
stats[2]: weightedMean // stats[3]: splitEval double[] stats = new double[4]; Map nodePred = new HashMap<>(); Map datasets = new HashMap<>(); Dataset dataset = null; if (this.cache != null) { dataset = Dataset.create(this.cache, instances); } else { dataset = Dataset.create(instances); } PriorityQueue> q = new PriorityQueue<>(); tree.root = createNode(dataset, limit, stats); q.add(new Element(tree.root, stats[2])); datasets.put(tree.root, dataset); nodePred.put(tree.root, stats[1]); int numLeaves = 0; while (!q.isEmpty()) { Element elemt = q.remove(); TreeNode node = elemt.element; Dataset data = datasets.get(node); if (!node.isLeaf()) { TreeInteriorNode interiorNode = (TreeInteriorNode) node; Dataset left = new Dataset(data.instances); Dataset right = new Dataset(data.instances); split(data, interiorNode, left, right); interiorNode.left = createNode(left, limit, stats); if (!interiorNode.left.isLeaf()) { nodePred.put(interiorNode.left, stats[2]); q.add(new Element(interiorNode.left, stats[3])); datasets.put(interiorNode.left, left); } else { numLeaves++; } interiorNode.right = createNode(right, limit, stats); if (!interiorNode.right.isLeaf()) { nodePred.put(interiorNode.right, stats[2]); q.add(new Element(interiorNode.right, stats[3])); datasets.put(interiorNode.right, right); } else { numLeaves++; } if (numLeaves + q.size() >= maxNumLeaves) { break; } } } // Convert interior nodes to leaves Map parent = new HashMap<>(); traverse(tree.root, parent); while (!q.isEmpty()) { Element elemt = q.remove(); TreeNode node = elemt.element; double prediction = nodePred.get(node); TreeInteriorNode interiorNode = (TreeInteriorNode) parent.get(node); if (interiorNode.left == node) { interiorNode.left = new RegressionTreeLeaf(prediction); } else { interiorNode.right = new RegressionTreeLeaf(prediction); } } return tree; } protected TreeNode createNode(Dataset dataset, int limit, double[] stats) { boolean stdIs0 = getStats(dataset.instances, stats); final double totalWeights = 
stats[0]; final double sum = stats[1]; final double weightedMean = stats[2]; // 1. Check basic leaf conditions if (dataset.instances.size() < limit || stdIs0) { TreeNode node = new RegressionTreeLeaf(weightedMean); return node; } // 2. Find best split double bestEval = Double.POSITIVE_INFINITY; List splits = new ArrayList<>(); List attributes = dataset.instances.getAttributes(); for (int j = 0; j < attributes.size(); j++) { int attIndex = attributes.get(j).getIndex(); String attName = attributes.get(j).getName(); List sortedList = dataset.sortedLists.get(attName); List uniqueValues = new ArrayList<>(sortedList.size()); List histogram = new ArrayList<>(sortedList.size()); getHistogram(dataset.instances, sortedList, uniqueValues, totalWeights, sum, histogram); if (uniqueValues.size() > 1) { DoublePair split = split(uniqueValues, histogram, totalWeights, sum); if (split.v2 <= bestEval) { IntDoublePair splitPoint = new IntDoublePair(attIndex, split.v1); if (split.v2 < bestEval) { splits.clear(); bestEval = split.v2; } splits.add(splitPoint); } } } if (bestEval < Double.POSITIVE_INFINITY) { Random rand = Random.getInstance(); IntDoublePair splitPoint = splits.get(rand.nextInt(splits.size())); int attIndex = splitPoint.v1; TreeNode node = new TreeInteriorNode(attIndex, splitPoint.v2); stats[3] = bestEval + totalWeights * weightedMean * weightedMean; return node; } else { TreeNode node = new RegressionTreeLeaf(weightedMean); return node; } } protected void split(Dataset data, TreeInteriorNode node, Dataset left, Dataset right) { data.split(node.getSplitAttributeIndex(), node.getSplitPoint(), left, right); } protected DoublePair split(List uniqueValues, List hist, double totalWeights, double sum) { double weight1 = hist.get(0).v1; double weight2 = totalWeights - weight1; double sum1 = hist.get(0).v2; double sum2 = sum - sum1; double bestEval = -(OptimUtils.getGain(sum1, weight1) + OptimUtils.getGain(sum2, weight2)); List splits = new ArrayList<>(); 
splits.add((uniqueValues.get(0) + uniqueValues.get(0 + 1)) / 2); for (int i = 1; i < uniqueValues.size() - 1; i++) { final double w = hist.get(i).v1; final double s = hist.get(i).v2; weight1 += w; weight2 -= w; sum1 += s; sum2 -= s; double eval1 = OptimUtils.getGain(sum1, weight1); double eval2 = OptimUtils.getGain(sum2, weight2); double eval = -(eval1 + eval2); if (eval <= bestEval) { double split = (uniqueValues.get(i) + uniqueValues.get(i + 1)) / 2; if (eval < bestEval) { bestEval = eval; splits.clear(); } splits.add(split); } } Random rand = Random.getInstance(); double split = splits.get(rand.nextInt(splits.size())); return new DoublePair(split, bestEval); } protected void traverse(TreeNode node, Map parent) { if (!node.isLeaf()) { TreeInteriorNode interiorNode = (TreeInteriorNode) node; if (interiorNode.left != null) { parent.put(interiorNode.left, node); traverse(interiorNode.left, parent); } if (interiorNode.right != null) { parent.put(interiorNode.right, node); traverse(interiorNode.right, parent); } } } } ================================================ FILE: src/main/java/mltk/predictor/tree/TreeInteriorNode.java ================================================ package mltk.predictor.tree; import java.io.BufferedReader; import java.io.PrintWriter; import mltk.core.Instance; /** * Class for tree interior nodes. * * @author Yin Lou * */ public class TreeInteriorNode extends TreeNode { protected TreeNode left; protected TreeNode right; protected int attIndex; protected double splitPoint; /** * Constructor. */ public TreeInteriorNode() { } /** * Constructs an interior node with attribute index and split point. * * @param attIndex the attribute index. * @param splitPoint the split point. */ public TreeInteriorNode(int attIndex, double splitPoint) { this.attIndex = attIndex; this.splitPoint = splitPoint; } /** * Returns the left child. * * @return the left child. */ public TreeNode getLeftChild() { return left; } /** * Returns the right child. 
* * @return the right child. */ public TreeNode getRightChild() { return right; } /** * Returns the split attribute index. * * @return the split attribute index. */ public int getSplitAttributeIndex() { return attIndex; } /** * Returns the split point. * * @return the split point. */ public double getSplitPoint() { return splitPoint; } @Override public boolean isLeaf() { return false; } /** * Returns {@code true} if going to left child. * * @param instance the instance. * @return {@code true} if going to left child. */ public boolean goLeft(Instance instance) { double value = instance.getValue(attIndex); return value <= splitPoint; } @Override public void read(BufferedReader in) throws Exception { attIndex = Integer.parseInt(in.readLine().split(": ")[1]); splitPoint = Double.parseDouble(in.readLine().split(": ")[1]); in.readLine(); Class clazzLeft = Class.forName(in.readLine()); left = (TreeNode) clazzLeft.getDeclaredConstructor().newInstance(); left.read(in); in.readLine(); Class clazzRight = Class.forName(in.readLine()); right = (TreeNode) clazzRight.getDeclaredConstructor().newInstance(); right.read(in); } @Override public void write(PrintWriter out) throws Exception { out.println(this.getClass().getCanonicalName()); out.println("AttIndex: " + attIndex); out.println("SplintPoint: " + splitPoint); out.println(); left.write(out); out.println(); right.write(out); } @Override public TreeNode copy() { TreeInteriorNode copy = new TreeInteriorNode(attIndex, splitPoint); copy.left = this.left.copy(); copy.right = this.right.copy(); return copy; } } ================================================ FILE: src/main/java/mltk/predictor/tree/TreeLearner.java ================================================ package mltk.predictor.tree; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import mltk.core.Attribute; import mltk.core.Instance; import mltk.core.Instances; import 
mltk.core.SparseVector; import mltk.predictor.Learner; import mltk.util.tuple.IntDoublePair; import mltk.util.tuple.IntDoublePairComparator; /** * Abstract class for learning trees. * * @author Yin Lou * */ public abstract class TreeLearner extends Learner { protected static final Double ZERO = Double.valueOf(0.0); protected static final IntDoublePairComparator COMP = new IntDoublePairComparator(false); protected Dataset cache; /** * Returns {@code true} if this tree learner can be used in {@link mltk.predictor.tree.ensemble.brt.LogitBoostLearner}. * * @return {@code true} if this tree learner can be used in {@link mltk.predictor.tree.ensemble.brt.LogitBoostLearner}. */ public abstract boolean isRobust(); /** * Caches the auxiliary data structures. This method is used in ensemble method * so that same data structures can be shared across iterations. * * @param instances the instances. */ public void cache(Instances instances) { cache = Dataset.create(instances); } /** * Evicts the cached data structures. */ public void evictCache() { cache = null; } /** * Sets the parameters for this tree learner. * * @param mode the parameters. 
*/ public abstract void setParameters(String mode); protected static class Dataset { static Dataset create(Instances instances) { Dataset dataset = new Dataset(instances); List attributes = instances.getAttributes(); // From attIndex to attName Map attMap = new HashMap<>(); for (int j = 0; j < attributes.size(); j++) { Attribute attribute = attributes.get(j); attMap.put(attribute.getIndex(), attribute.getName()); } for (Attribute attribute : attributes) { dataset.sortedLists.put(attribute.getName(), new ArrayList()); } for (int i = 0; i < instances.size(); i++) { Instance instance = instances.get(i); dataset.instances.add(instance.clone()); if (instance.isSparse()) { SparseVector sv = (SparseVector) instance.getVector(); int[] indices = sv.getIndices(); double[] values = sv.getValues(); for (int k = 0; k < indices.length; k++) { if (attMap.containsKey(indices[k])) { String attName = attMap.get(indices[k]); dataset.sortedLists.get(attName).add(new IntDoublePair(i, values[k])); } } } else { double[] values = instance.getValues(); for (int j = 0; j < values.length; j++) { if (attMap.containsKey(j) && values[j] != 0.0) { String attName = attMap.get(j); dataset.sortedLists.get(attName).add(new IntDoublePair(i, values[j])); } } } } for (List sortedList : dataset.sortedLists.values()) { Collections.sort(sortedList, COMP); } return dataset; } static Dataset create(Dataset dataset, Instances instances) { Dataset copy = new Dataset(); copy.instances = instances; copy.sortedLists = new HashMap<>(instances.dimension()); List attributes = instances.getAttributes(); for (Attribute attribute : attributes) { String attName = attribute.getName(); List sortedList = dataset.sortedLists.get(attName); if (sortedList == null) { // This should not happen very often sortedList = new ArrayList<>(); for (int i = 0; i < instances.size(); i++) { Instance instance = instances.get(i); double v = instance.getValue(attribute); if (v != 0.0) { sortedList.add(new IntDoublePair(i, v)); } } 
Collections.sort(sortedList, COMP); dataset.sortedLists.put(attName, sortedList); } List copySortedList = new ArrayList<>(sortedList.size()); for (IntDoublePair pair : sortedList) { copySortedList.add(new IntDoublePair(pair.v1, pair.v2)); } copy.sortedLists.put(attName, copySortedList); } return copy; } public Instances instances; public Map> sortedLists; Dataset() { } Dataset(Instances instances) { this.instances = new Instances(instances.getAttributes(), instances.getTargetAttribute()); sortedLists = new HashMap<>(instances.dimension()); } static Dataset merge(Dataset left, Dataset right) { Dataset data = new Dataset(left.instances); int lSize = left.instances.size(); for (Instance instance : left.instances) { data.instances.add(instance); } for (Instance instance : right.instances) { data.instances.add(instance); } for (String attName : left.sortedLists.keySet()) { List lSortedList = left.sortedLists.get(attName); List rSortedList = right.sortedLists.get(attName); List sortedList = new ArrayList<>(data.instances.size()); int i = 0; int j = 0; while (i < lSortedList.size() && j < rSortedList.size()) { IntDoublePair l = lSortedList.get(i); IntDoublePair r = rSortedList.get(j); if (l.v2 < r.v2) { sortedList.add(l); i++; } else if (l.v2 > r.v2) { r.v1 += + lSize; sortedList.add(r); j++; } else { sortedList.add(l); r.v1 += lSize; sortedList.add(r); i++; j++; } } while (i < lSortedList.size()) { IntDoublePair l = lSortedList.get(i); sortedList.add(l); i++; } while (j < rSortedList.size()) { IntDoublePair r = rSortedList.get(j); r.v1 += lSize; sortedList.add(r); j++; } data.sortedLists.put(attName, sortedList); } return data; } void split(int attIndex, double split, Dataset left, Dataset right) { int[] leftHash = new int[instances.size()]; int[] rightHash = new int[instances.size()]; Arrays.fill(leftHash, -1); Arrays.fill(rightHash, -1); for (int i = 0; i < instances.size(); i++) { Instance instance = instances.get(i); if (instance.getValue(attIndex) <= split) { 
left.instances.add(instance); leftHash[i] = left.instances.size() - 1; } else { right.instances.add(instance); rightHash[i] = right.instances.size() - 1; } } for (String attName : sortedLists.keySet()) { left.sortedLists.put(attName, new ArrayList(left.instances.size())); right.sortedLists.put(attName, new ArrayList(right.instances.size())); List sortedList = sortedLists.get(attName); for (IntDoublePair pair : sortedList) { int leftIdx = leftHash[pair.v1]; int rightIdx = rightHash[pair.v1]; if (leftIdx != -1) { pair.v1 = leftIdx; left.sortedLists.get(attName).add(pair); } if (rightIdx != -1) { pair.v1 = rightIdx; right.sortedLists.get(attName).add(pair); } } } } } } ================================================ FILE: src/main/java/mltk/predictor/tree/TreeNode.java ================================================ package mltk.predictor.tree; import mltk.core.Copyable; import mltk.core.Writable; /** * Abstract class for tree nodes. * * @author Yin Lou * */ public abstract class TreeNode implements Writable, Copyable { /** * Returns {@code true} if the node is a leaf. * * @return {@code true} if the node is a leaf. */ public abstract boolean isLeaf(); } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/BaggedRTrees.java ================================================ package mltk.predictor.tree.ensemble; import java.util.ArrayList; import mltk.core.Instance; import mltk.predictor.tree.RTree; /** * Class for bagged regression trees. * * @author Yin Lou * */ public class BaggedRTrees extends RTreeList { /** * Constructor. */ public BaggedRTrees() { super(); } /** * Constructs a regression tree list of length n. By default each tree is null. * * @param n the length. */ public BaggedRTrees(int n) { trees = new ArrayList<>(n); for (int i = 0; i < n; i++) { trees.add(null); } } /** * Regresses an instance. * * @param instance the instance. 
* @return the regressed */ public double regress(Instance instance) { double pred = 0; for (RTree rt : trees) { pred += rt.regress(instance); } return pred / trees.size(); } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/BoostedDTables.java ================================================ package mltk.predictor.tree.ensemble; import java.io.BufferedReader; import java.io.PrintWriter; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import mltk.core.Copyable; import mltk.core.Instance; import mltk.core.SparseVector; import mltk.predictor.tree.DecisionTable; import mltk.predictor.tree.RTree; /** * Class for boosted decision tables. * * @author Yin Lou * */ public class BoostedDTables implements Copyable { static class IndexElement implements Comparable { double cut; int tid; int pos; public IndexElement(double cut, int tid, int pos) { this.cut = cut; this.tid = tid; this.pos = pos; } @Override public int compareTo(IndexElement o) { return Double.compare(o.cut, this.cut); } public IndexElement copy() { return new IndexElement(cut, tid, pos); } } static class Index { IndexElement[] elements; public Index(IndexElement[] elements) { this.elements = elements; } void setPredIdx(long[] predIndices, double v) { for (IndexElement element : elements) { if (v <= element.cut) { long t = 1 << element.pos; int tid = element.tid; predIndices[tid] |= t; } else { break; } } } public Index copy() { IndexElement[] elementsCopy = new IndexElement[elements.length]; for (int i = 0; i < elementsCopy.length; i++) { elementsCopy[i] = elements[i].copy(); } return new Index(elementsCopy); } } protected static final IndexElement[] EMPTY_INDEX = new IndexElement[0]; protected List dtList; protected Index[] indexes; /** * Constructor. */ public BoostedDTables() { dtList = new ArrayList<>(); } /** * Constructor. 
* * @param trees */ public BoostedDTables(BoostedRTrees trees) { dtList = new ArrayList<>(trees.size()); for (RTree tree : trees) { dtList.add((DecisionTable) tree); } buildIndex(); } /** * Builds the index for fast scoring. */ public void buildIndex() { Map> map = new HashMap<>(); int p = -1; for (int i = 0; i < dtList.size(); i++) { DecisionTable dt = dtList.get(i); int[] attIndices = dt.getAttributeIndices(); double[] cuts = dt.getSplits(); for (int k = 0; k < attIndices.length; k++) { IndexElement element = new IndexElement(cuts[k], i, attIndices.length - k - 1); if (attIndices[k] > p) { p = attIndices[k]; } int attIdx = attIndices[k]; if (!map.containsKey(attIdx)) { map.put(attIdx, new ArrayList()); } map.get(attIdx).add(element); } } p++; indexes = new Index[p]; for (int j = 0; j < p; j++) { List elements = map.get(j); if (elements != null) { Collections.sort(elements); indexes[j] = new Index(elements.toArray(new IndexElement[elements.size()])); } else { indexes[j] = new Index(EMPTY_INDEX); } } } /** * Adds a decision table to the list. * * @param dt the decision table to add. */ public void add(DecisionTable dt) { dtList.add(dt); } /** * Returns the table at the specified position in this list. * * @param index the index of the element to return. * @return the table at the specified position in this list. */ public DecisionTable get(int index) { return dtList.get(index); } /** * Removes the last tree. */ public void removeLast() { if (dtList.size() > 0) { dtList.remove(dtList.size() - 1); } } /** * Returns the size of this boosted decision table list. * * @return the size of this boosted decision table list. */ public int size() { return dtList.size(); } /** * Replaces the table at the specified position in this list with the new table. * * @param index the index of the element to replace. * @param dt the decision table to be stored at the specified position. 
*/ public void set(int index, DecisionTable dt) { dtList.set(index, dt); } /** * Regresses an instance. * * @param instance the instance to regress. * @return a regressed value. */ public double regress(Instance instance) { double[] values = instance.getValues(); long[] predIndices = new long[dtList.size()]; if (instance.isSparse()) { int[] indices = ((SparseVector) instance.getVector()).getIndices(); for (int j = 0; j < Math.min(values.length, indexes.length); j++) { double v = values[j]; int idx = indices[j]; indexes[idx].setPredIdx(predIndices, v); } } else { for (int j = 0; j < Math.min(values.length, indexes.length); j++) { double v = values[j]; indexes[j].setPredIdx(predIndices, v); } } double pred = 0; for (int i = 0 ; i < predIndices.length; i++) { DecisionTable dt = dtList.get(i); pred += dt.regress(predIndices[i]); } return pred; } @Override public BoostedDTables copy() { return copy(true); } /** * Copies this object. * * @param copyIndexes {@code true} if the indexes are also copied; * @return this object. */ public BoostedDTables copy(boolean copyIndexes) { BoostedDTables copy = new BoostedDTables(); copy.dtList = new ArrayList<>(); for (DecisionTable dt : dtList) { copy.dtList.add(dt.copy()); } if (copyIndexes) { copy.indexes = new Index[this.indexes.length]; for (int i = 0; i < copy.indexes.length; i++) { copy.indexes[i] = indexes[i].copy(); } } return copy; } /** * Reads in this boosted decision tables. * * @param in the reader. * @throws Exception */ public void read(BufferedReader in) throws Exception { int n = Integer.parseInt(in.readLine().split(": ")[1]); for (int j = 0; j < n; j++) { String line = in.readLine(); String predictorName = line.substring(1, line.length() - 1).split(": ")[1]; Class clazz = Class.forName(predictorName); DecisionTable dt = (DecisionTable) clazz.getDeclaredConstructor().newInstance(); dt.read(in); this.dtList.add(dt); in.readLine(); } buildIndex(); } /** * Writes this boosted decision tables. * * @param out the writer. 
* @throws Exception */ public void write(PrintWriter out) throws Exception { out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName()); out.println("Length: " + dtList.size()); for (DecisionTable dt : dtList) { dt.write(out); out.println(); } } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/BoostedRTrees.java ================================================ package mltk.predictor.tree.ensemble; import java.io.BufferedReader; import java.io.PrintWriter; import java.util.ArrayList; import mltk.core.Instance; import mltk.predictor.tree.RTree; /** * Class for boosted regression trees. This is a base class for BRT. * * @author Yin Lou * */ public class BoostedRTrees extends RTreeList { /** * Constructor. */ public BoostedRTrees() { super(); } /** * Constructs a regression tree list of length n. By default each tree is null. * * @param n the length. */ public BoostedRTrees(int n) { trees = new ArrayList<>(n); for (int i = 0; i < n; i++) { trees.add(null); } } /** * Regresses an instance. * * @param instance the instance. 
* @return the regressed */ public double regress(Instance instance) { double pred = 0; for (RTree rt : trees) { pred += rt.regress(instance); } return pred; } @Override public BoostedRTrees copy() { BoostedRTrees copy = new BoostedRTrees(); for (RTree rt : trees) { copy.trees.add(rt.copy()); } return copy; } public void read(BufferedReader in) throws Exception { int n = Integer.parseInt(in.readLine().split(": ")[1]); for (int j = 0; j < n; j++) { String line = in.readLine(); String predictorName = line.substring(1, line.length() - 1).split(": ")[1]; Class clazz = Class.forName(predictorName); RTree rt = (RTree) clazz.getDeclaredConstructor().newInstance(); rt.read(in); this.trees.add(rt); in.readLine(); } } public void write(PrintWriter out) throws Exception { out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName()); out.println("Length: " + trees.size()); for (RTree rt : trees) { rt.write(out); out.println(); } } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/RTreeList.java ================================================ package mltk.predictor.tree.ensemble; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import mltk.core.Copyable; import mltk.predictor.tree.RTree; /** * Class for regression tree list. * * @author Yin Lou * */ public class RTreeList implements Iterable, Copyable { protected List trees; /** * Constructor. */ public RTreeList() { trees = new ArrayList<>(); } /** * Constructor. * * @param capacity the capacity of this tree list. */ public RTreeList(int capacity) { trees = new ArrayList<>(capacity); } /** * Adds a regression tree to the list. * * @param tree the regression tree to add. */ public void add(RTree tree) { trees.add(tree); } @Override public RTreeList copy() { RTreeList copy = new RTreeList(); for (RTree rt : trees) { copy.trees.add(rt.copy()); } return copy; } /** * Returns the tree at the specified position in this list. 
* * @param index the index of the element to return. * @return the tree at the specified position in this list. */ public RTree get(int index) { return trees.get(index); } @Override public Iterator iterator() { return trees.iterator(); } /** * Removes the last tree. */ public void removeLast() { if (trees.size() > 0) { trees.remove(trees.size() - 1); } } /** * Replaces the tree at the specified position in this list with the new tree. * * @param index the index of the element to replace. * @param rt the regression tree to be stored at the specified position. */ public void set(int index, RTree rt) { trees.set(index, rt); } /** * Returns the size of this list. * * @return the size of this list. */ public int size() { return trees.size(); } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/TreeEnsembleLearner.java ================================================ package mltk.predictor.tree.ensemble; import mltk.predictor.HoldoutValidatedLearner; import mltk.predictor.tree.TreeLearner; /** * Class for learning tree ensembles. * * @author Yin Lou * */ public abstract class TreeEnsembleLearner extends HoldoutValidatedLearner { protected TreeLearner treeLearner; public TreeLearner getTreeLearner() { return treeLearner; } public void setTreeLearner(TreeLearner treeLearner) { this.treeLearner = treeLearner; } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/ag/AdditiveGroves.java ================================================ package mltk.predictor.tree.ensemble.ag; import java.io.BufferedReader; import java.io.PrintWriter; import java.util.ArrayList; import java.util.List; import mltk.core.Instance; import mltk.predictor.Regressor; import mltk.predictor.tree.RegressionTree; /** * Class for Additive Groves. * * @author Yin Lou * */ public class AdditiveGroves implements Regressor { protected List groves; /** * Constructor. 
*/ public AdditiveGroves() { groves = new ArrayList<>(); } @Override public void read(BufferedReader in) throws Exception { int bn = Integer.parseInt(in.readLine().split(": ")[1]); groves = new ArrayList<>(); for (int i = 0; i < bn; i++) { int tn = Integer.parseInt(in.readLine().split(": ")[1]); RegressionTree[] grove = new RegressionTree[tn]; for (int j = 0; j < tn; j++) { in.readLine(); RegressionTree rt = new RegressionTree(); rt.read(in); grove[i] = rt; in.readLine(); } groves.add(grove); in.readLine(); } } @Override public void write(PrintWriter out) throws Exception { out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName()); out.println("Bagging: " + groves.size()); for (RegressionTree[] grove : groves) { out.println("Size: " + grove.length); for (RegressionTree rt : grove) { rt.write(out); out.println(); } out.println(); } } @Override public double regress(Instance instance) { if (groves.size() == 0) { return 0; } double pred = 0; for (RegressionTree[] grove : groves) { for (RegressionTree rt : grove) { pred += rt.regress(instance); } } return pred / groves.size(); } @Override public AdditiveGroves copy() { AdditiveGroves copy = new AdditiveGroves(); for (RegressionTree[] grove : groves) { RegressionTree[] copyGrove = new RegressionTree[grove.length]; for (int i = 0; i < copyGrove.length; i++) { copyGrove[i] = grove[i].copy(); } copy.groves.add(copyGrove); } return copy; } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/ag/AdditiveGrovesLearner.java ================================================ package mltk.predictor.tree.ensemble.ag; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.cmdline.options.LearnerOptions; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.Sampling; import mltk.core.io.InstancesReader; import mltk.predictor.Learner; import 
mltk.predictor.evaluation.AUC; import mltk.predictor.evaluation.RMSE; import mltk.predictor.evaluation.SimpleMetric; import mltk.predictor.io.PredictorWriter; import mltk.predictor.tree.RegressionTree; import mltk.predictor.tree.RegressionTreeLearner; import mltk.predictor.tree.RegressionTreeLearner.Mode; import mltk.util.OptimUtils; import mltk.util.Random; import mltk.util.tuple.IntPair; /** * Class for learning Additive Groves. This class currently only supports layered training. * *

* Reference:
* D. Sorokina, R. Caruana and M. Riedewald. Additive Groves of Regression Trees. In Proceedings of the 18th European * Conference on Machine Learning (ECML), Warsaw, Poland, 2007. *

* * @author Yin Lou * */ public class AdditiveGrovesLearner extends Learner { static class Options extends LearnerOptions { @Argument(name = "-v", description = "valid set path", required = true) String validPath = null; @Argument(name = "-o", description = "output model path") String outputModelPath = null; @Argument(name = "-e", description = "AUC (a), RMSE (r) (default: r)") String metric = "rmse"; @Argument(name = "-b", description = "bagging iterations (default: 60)") int baggingIters = 60; @Argument(name = "-n", description = "number of trees in a grove (default: 6)") int n = 6; @Argument(name = "-a", description = "minimum alpha (default: 0.01)") double a = 0.01; @Argument(name = "-s", description = "seed of the random number generator (default: 0)") long seed = 0L; } /** * Trains an additive groves model. * *
	 * Usage: mltk.predictor.tree.ensemble.ag.AdditiveGrovesLearner
	 * -t	train set path
	 * -v	valid set path
	 * [-r]	attribute file path
	 * [-o]	output model path
	 * [-V]	verbose (default: true)
	 * [-o]	output model path
	 * [-e]	AUC (a), RMSE (r) (default: r)
	 * [-b]	bagging iterations (default: 60)
	 * [-n]	number of trees in a grove (default: 6)
	 * [-a]	minimum alpha (default: 0.01)
	 * [-s]	seed of the random number generator (default: 0)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(AdditiveGrovesLearner.class, opts); SimpleMetric metric = null; try { parser.parse(args); if ("rmse".startsWith(opts.metric)) { metric = new RMSE(); } else if ("auc".startsWith(opts.metric)) { metric = new AUC(); } } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Random.getInstance().setSeed(opts.seed); Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath); Instances validSet = InstancesReader.read(opts.attPath, opts.validPath); AdditiveGrovesLearner learner = new AdditiveGrovesLearner(); learner.setBaggingIters(opts.baggingIters); learner.setNumTrees(opts.n); learner.setMinAlpha(opts.a); learner.setMetric(metric); learner.setVerbose(opts.verbose); long start = System.currentTimeMillis(); AdditiveGroves ag = learner.buildRegressor(trainSet, validSet); long end = System.currentTimeMillis(); System.out.println("Time: " + (end - start) / 1000.0); if (opts.outputModelPath != null) { PredictorWriter.write(ag, opts.outputModelPath); } } class PerformanceMatrix { SimpleMetric metric; double[][][] perf; PerformanceMatrix(int maxNumTrees, int numAlphas, int baggingIters, SimpleMetric metric) { perf = new double[maxNumTrees][numAlphas][baggingIters]; this.metric = metric; } void expand(int maxNumTrees, int numAlphas, int baggingIters) { // TODO make sure all dimensions are increasing only double[][][] newPerf = new double[maxNumTrees][numAlphas][baggingIters]; for (int i = 0; i < perf.length; i++) { double[][] t1Old = perf[i]; double[][] t1New = newPerf[i]; for (int j = 0; j < t1Old.length; j++) { double[] t2Old = t1Old[j]; double[] t2New = t1New[j]; for (int k = 0; k < t2Old.length; k++) { t2New[k] = t2Old[k]; } } } perf = newPerf; } void eval(int t, int a, int b, double[] preds, double[] targets) { perf[t][a][b] = metric.eval(preds, 
targets); } IntPair getBestParameters() { int bestT = 0; int bestA = 0; double bestPerf = metric.worstValue(); if (verbose) { System.out.println("Perf Matrix:"); for (int i = 0; i < perf.length; i++) { double[][] tOutter = perf[i]; for (int j = 0; j < tOutter.length; j++) { double[] tInner = tOutter[j]; double p = tInner[tInner.length - 1]; System.out.print(p + " "); if (metric.isFirstBetter(p, bestPerf)) { bestT = i; bestA = j; bestPerf = p; } } System.out.println(); } System.out.println("Best perf on validation set = " + bestPerf); } return new IntPair(bestT, bestA); } /** * Returns {@code true} if the bagging converges. * * @param t the number of trees. * @param a the index of alpha. * @return {@code true} if the bagging converges. */ boolean analyzeBagging(int t, int a) { return OptimUtils.isConverged(perf[t][a], metric.isLargerBetter()); } } class ModelMatrix { AdditiveGroves[][] groves; ModelMatrix(int maxNumTrees, int numAlphas) { groves = new AdditiveGroves[maxNumTrees][numAlphas]; } void expand(int maxNumTrees, int numAlphas, int baggingIters) { AdditiveGroves[][] newGroves = new AdditiveGroves[maxNumTrees][numAlphas]; for (int i = 0; i < groves.length; i++) { for (int j = 0; j < groves[i].length; j++) { newGroves[i][j] = groves[i][j]; } } groves = newGroves; } void add(int t, int a, RegressionTree[] grove) { if (groves[t][a] == null) { groves[t][a] = new AdditiveGroves(); } groves[t][a].groves.add(grove); } } class PredictionMatrix { double[][][] sumPrediction; int n; PredictionMatrix(int tn, int an, int n) { sumPrediction = new double[tn][an][n]; this.n = n; } void expand(int tn, int an) { double[][][] newSumPrediction = new double[tn][an][n]; for (int t = 0; t < sumPrediction.length; t++) { double[][] tSrc = sumPrediction[t]; double[][] tDes = newSumPrediction[t]; for (int a = 0; a < tSrc.length; a++) { System.arraycopy(tSrc[a], 0, tDes[a], 0, n); } } sumPrediction = newSumPrediction; } } private int bestNumTrees; private int bestBaggingIters; private 
double bestAlpha; private int numTrees; private int baggingIters; private double minAlpha; private SimpleMetric metric; /** * Constructor. */ public AdditiveGrovesLearner() { verbose = false; numTrees = 6; baggingIters = 60; minAlpha = 0.01; metric = new RMSE(); } /** * Returns the metric. * * @return the metric. */ public SimpleMetric getMetric() { return metric; } /** * Sets the metric. Currently only support {@link mltk.predictor.evaluation.RMSE RMSE} and * {@link mltk.predictor.evaluation.AUC AUC}. * * @param metric the metric. */ public void setMetric(SimpleMetric metric) { if (metric instanceof RMSE || metric instanceof AUC) { this.metric = metric; } } /** * Returns the best number of trees in a grove from latest run. * * @return the best number of trees in a grove from latest run. */ public int getBestNumTrees() { return bestNumTrees; } /** * Returns the best bagging iterations from latest run. * * @return the best bagging iterations from latest run. */ public int getBestBaggingIters() { return bestBaggingIters; } /** * Returns the best alpha from latest run. * * @return the best alpha from latest run. */ public double getBestAlpha() { return bestAlpha; } /** * Returns the number of trees in a grove. The number of trees may be adjusted during the training. * * @return the number of trees in a grove. */ public int getNumTrees() { return numTrees; } /** * Sets the number of trees in a grove. The number of trees may be adjusted during the training. * * @param numTrees the number of trees in a grove. */ public void setNumTrees(int numTrees) { this.numTrees = numTrees; } /** * Returns the minimum alpha. The minimum alpha may be adjusted during the training. * * @return the minimum alpha. */ public double getMinAlpha() { return minAlpha; } /** * Sets the minimum alpha. The minimum alpha may be adjusted during the training. * * @param minAlpha the minimum alpha. 
*/ public void setMinAlpha(double minAlpha) { this.minAlpha = minAlpha; } /** * Returns the number of bagging iterations. The number of bagging iterations may be adjusted during the training. * * @return the number of bagging iterations. */ public int getBaggingIters() { return baggingIters; } /** * Sets the number of bagging iterations. The number of bagging iterations may be adjusted during the training. * * @param baggingIters the bagging iterations. */ public void setBaggingIters(int baggingIters) { this.baggingIters = baggingIters; } /** * Builds additive groves. * * @param trainSet the training set. * @param validSet the validation set. * @return a regressor. */ public AdditiveGroves buildRegressor(Instances trainSet, Instances validSet) { int bn = baggingIters; int tn = numTrees; int an = 6; List alphas = new ArrayList<>(); for (int a = 0; a < an; a++) { alphas.add(getAlpha(a)); } int prevBN = 0; int prevTN = 0; int prevAN = 0; // Backup targets double[] targetTrain = new double[trainSet.size()]; for (int i = 0; i < targetTrain.length; i++) { targetTrain[i] = trainSet.get(i).getTarget(); } double[] targetValid = new double[validSet.size()]; for (int i = 0; i < targetValid.length; i++) { targetValid[i] = validSet.get(i).getTarget(); } PerformanceMatrix perfMatrix = new PerformanceMatrix(tn, an, bn, metric); ModelMatrix modelMatrix = new ModelMatrix(tn, an); PredictionMatrix predMatrix = new PredictionMatrix(tn, an, validSet.size()); IntPair bestParams = null; for (;;) { // Running bagging iterations runLayeredTraining(trainSet, validSet, prevBN, bn, 0, tn, 0, an, alphas, perfMatrix, modelMatrix, predMatrix, targetTrain, targetValid); bestParams = perfMatrix.getBestParameters(); boolean converged = true; prevBN = bn; prevTN = tn; prevAN = an; // Expand alpha if (bestParams.v2 == an - 1 && alphas.get(an - 1) > 1.0 / trainSet.size()) { converged = false; an += 3; for (int a = prevAN; a < an; a++) { alphas.add(getAlpha(a)); } System.out.println(an); 
predMatrix.expand(tn, an); perfMatrix.expand(tn, an, bn); modelMatrix.expand(tn, an, bn); runLayeredTraining(trainSet, validSet, 0, bn, 0, tn, prevAN, an, alphas, perfMatrix, modelMatrix, predMatrix, targetTrain, targetValid); } // Expand number of trees if (bestParams.v1 == tn - 1) { converged = false; tn += 3; predMatrix.expand(tn, an); perfMatrix.expand(tn, an, bn); modelMatrix.expand(tn, an, bn); runLayeredTraining(trainSet, validSet, 0, bn, prevTN, tn, 0, an, alphas, perfMatrix, modelMatrix, predMatrix, targetTrain, targetValid); } // Expand bagging if (!perfMatrix.analyzeBagging(bestParams.v1, bestParams.v2)) { converged = false; bn += 40; predMatrix.expand(tn, an); perfMatrix.expand(tn, an, bn); modelMatrix.expand(tn, an, bn); } if (converged) { break; } } // Restore targets for (int i = 0; i < targetTrain.length; i++) { trainSet.get(i).setTarget(targetTrain[i]); } System.out.println("Best model:"); System.out.println("Alpha = " + alphas.get(bestParams.v2)); System.out.println("N = " + (bestParams.v1 + 1)); System.out.println("b = " + bn); bestBaggingIters = bn; bestNumTrees = bestParams.v1 + 1; bestAlpha = getAlpha(bestParams.v2); return modelMatrix.groves[bestParams.v1][bestParams.v2]; } /** * Builds additive groves using layered training. * * @param trainSet the training set. * @param baggingIters the number of bagging iterations. * @param numTrees the number of trees in a grove. * @param alpha the alpha. * @return a regressor. 
*/ public AdditiveGroves runLayeredTraining(Instances trainSet, int baggingIters, int numTrees, double alpha) { final int n = trainSet.size(); // Backup targets double[] targetTrain = new double[trainSet.size()]; for (int i = 0; i < targetTrain.length; i++) { targetTrain[i] = trainSet.get(i).getTarget(); } int bn = baggingIters; int tn = numTrees; int an = getAlphaIdx(alpha, trainSet.size()) + 1; AdditiveGroves ag = new AdditiveGroves(); for (int b = 0; b < bn; b++) { // The most recent predictions for regression trees double[][] rtPreds = new double[tn][n]; // The most recent residuals double[] residualTrain = new double[n]; for (int i = 0; i < n; i++) { residualTrain[i] = targetTrain[i]; } if (verbose) { System.out.println("Iteration " + (b + 1) + " out of " + bn); } for (int a = 0; a < an; a++) { double currAlpha = getAlpha(a); if (verbose) { System.out.println("\tBuilding models with alpha = " + currAlpha); } RegressionTree[] grove = new RegressionTree[tn]; backfit(trainSet, currAlpha, grove, rtPreds, residualTrain); ag.groves.add(grove); } } // Restore targets for (int i = 0; i < targetTrain.length; i++) { trainSet.get(i).setTarget(targetTrain[i]); } return ag; } @Override public AdditiveGroves build(Instances instances) { Instances trainSet = new Instances(instances.getAttributes(), instances.getTargetAttribute()); Instances validSet = new Instances(instances.getAttributes(), instances.getTargetAttribute()); int nTrain = instances.size() / 5 * 4; for (int i = 0; i < nTrain; i++) { trainSet.add(instances.get(i)); } for (int i = nTrain; i < instances.size(); i++) { validSet.add(instances.get(i)); } return buildRegressor(trainSet, validSet); } protected double getAlpha(int an) { double alpha = 1; if (an % 3 == 0) { alpha = 5; } else if (an % 3 == 1) { alpha = 2; } for (int i = 0; i < an / 3 + 1; i++) { alpha /= 10; } return alpha; } protected int getAlphaIdx(double alpha, int n) { int idx = 0; double min = 1.0 / n; while (alpha < getAlpha(idx) && min < alpha) { 
idx++; } return idx; } protected void backfit(Instances trainSet, double alpha, RegressionTree[] grove, double[][] rtPreds, double[] residualTrain) { Map bagIndices = new HashMap<>(); List oobIndices = new ArrayList<>(); Sampling.createBootstrapSample(trainSet, bagIndices, oobIndices); Instances bag = new Instances(trainSet.getAttributes(), trainSet.getTargetAttribute(), bagIndices.size()); for (Integer idx : bagIndices.keySet()) { int weight = bagIndices.get(idx); Instance instance = trainSet.get(idx).clone(); instance.setWeight(weight); bag.add(instance); } RegressionTreeLearner rtLearner = new RegressionTreeLearner(); rtLearner.setConstructionMode(Mode.ALPHA_LIMITED); rtLearner.setAlpha(alpha); double prevRMSE = evalRMSE(oobIndices, residualTrain); for (;;) { for (int iter = 0; iter < grove.length; iter++) { int treeIdx = (iter + grove.length - 1) % grove.length; double[] treePreds = rtPreds[treeIdx]; for (int i = 0; i < residualTrain.length; i++) { residualTrain[i] += treePreds[i]; trainSet.get(i).setTarget(residualTrain[i]); } RegressionTree rt = rtLearner.build(bag); grove[treeIdx] = rt; for (int i = 0; i < residualTrain.length; i++) { double pred = rt.regress(trainSet.get(i)); treePreds[i] = pred; residualTrain[i] -= pred; } } double currRMSE = evalRMSE(oobIndices, residualTrain); if (currRMSE == 0 || (prevRMSE - currRMSE) / prevRMSE <= 0.002) { break; } else { prevRMSE = currRMSE; } } } protected double regress(RegressionTree[] trees, Instance instance) { double pred = 0; for (RegressionTree rt : trees) { pred += rt.regress(instance); } return pred; } protected void runLayeredTraining(Instances trainSet, Instances validSet, int bStart, int bEnd, int tStart, int tEnd, int aStart, int aEnd, List alphas, PerformanceMatrix perfMatrix, ModelMatrix modelMatrix, PredictionMatrix predMatrix, double[] targetTrain, double[] targetValid) { final int n = trainSet.size(); final int tLen = tEnd - tStart; double[] predictionValid = new double[validSet.size()]; for (int b 
= bStart; b < bEnd; b++) { // The most recent predictions for regression trees double[][][] rtPreds = new double[tLen][tEnd][n]; // The most recent residuals double[][] residualTrain = new double[tLen][n]; for (int t = 0; t < tLen; t++) { double[] residual = residualTrain[t]; for (int i = 0; i < n; i++) { residual[i] = targetTrain[i]; } } if (aStart != 0) { for (int t = 0; t < tEnd; t++) { RegressionTree[] grove = modelMatrix.groves[t][aStart - 1].groves.get(b); update(trainSet, grove, rtPreds, residualTrain, t); } } if (verbose) { System.out.println("Iteration " + (b + 1) + " out of " + bEnd); } for (int a = aStart; a < aEnd; a++) { double alpha = alphas.get(a); if (verbose) { System.out.println("\tBuilding models with alpha = " + alpha); } for (int t = tStart; t < tEnd; t++) { int numTrees = t + 1; int tIdx = t - tStart; RegressionTree[] grove = new RegressionTree[numTrees]; backfit(trainSet, alpha, grove, rtPreds[tIdx], residualTrain[tIdx]); modelMatrix.add(t, a, grove); // Update predictions double[] sumPredictionValid = predMatrix.sumPrediction[t][a]; for (int i = 0; i < sumPredictionValid.length; i++) { sumPredictionValid[i] += regress(grove, validSet.get(i)); predictionValid[i] = sumPredictionValid[i] / (b + 1); } perfMatrix.eval(t, a, b, predictionValid, targetValid); } } } } protected double evalRMSE(List indices, double[] residual) { double rmse = 0; for (Integer idx : indices) { double d = residual[idx]; rmse += d * d; } rmse = Math.sqrt(rmse / indices.size()); return rmse; } protected void update(Instances trainSet, RegressionTree[] grove, double[][][] rtPreds, double[][] residualTrain, int maxTN) { for (int t = 0; t < maxTN; t++) { RegressionTree rt = grove[t]; for (int i = 0; i < trainSet.size(); i++) { double pred = rt.regress(trainSet.get(i)); rtPreds[maxTN][t][i] = pred; residualTrain[maxTN][i] -= pred; } } } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/ag/package-info.java 
================================================ /** * Provides algorithms for fitting additive groves (AGs). */ package mltk.predictor.tree.ensemble.ag; ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/brt/BDT.java ================================================ package mltk.predictor.tree.ensemble.brt; import java.io.BufferedReader; import java.io.PrintWriter; import mltk.core.Instance; import mltk.predictor.ProbabilisticClassifier; import mltk.predictor.Regressor; import mltk.predictor.tree.ensemble.BoostedDTables; import mltk.util.MathUtils; import mltk.util.StatUtils; import mltk.util.VectorUtils; /** * Class for boosted decision tables (BDTs). * *

* Reference:
* Y. Lou and M. Obukhov. BDT: Boosting Decision Tables for High Accuracy and Scoring Efficiency. In Proceedings of the * 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), Halifax, Nova Scotia, Canada, 2017. *

* * @author Yin Lou * */ public class BDT implements ProbabilisticClassifier, Regressor { protected BoostedDTables[] tables; /** * Constructs a BDT from a BRT object. * * @param brt the BRT object. * @return a BDT object. */ public static BDT constructBDT(BRT brt) { int k = brt.trees.length; BDT bdt = new BDT(); bdt.tables = new BoostedDTables[k]; for (int i = 0; i < bdt.tables.length; i++) { bdt.tables[i] = new BoostedDTables(brt.trees[i]); } return bdt; } /** * Constructor. */ public BDT() { } /** * Constructor. * * @param k the number of classes. */ public BDT(int k) { tables = new BoostedDTables[k]; for (int i = 0; i < tables.length; i++) { tables[i] = new BoostedDTables(); } } /** * Returns the table list for class k. * * @param k the class k. * @return the table list for class k. */ public BoostedDTables getDecisionTreeList(int k) { return tables[k]; } @Override public int classify(Instance instance) { double[] prob = predictProbabilities(instance); return StatUtils.indexOfMax(prob); } @Override public void read(BufferedReader in) throws Exception { int k = Integer.parseInt(in.readLine().split(": ")[1]); tables = new BoostedDTables[k]; for (int i = 0; i < tables.length; i++) { in.readLine(); tables[i] = new BoostedDTables(); tables[i].read(in); in.readLine(); } } @Override public void write(PrintWriter out) throws Exception { out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName()); out.println("K: " + tables.length); for (BoostedDTables dtList : tables) { dtList.write(out); out.println(); } } @Override public double regress(Instance instance) { return tables[0].regress(instance); } @Override public double[] predictProbabilities(Instance instance) { if (tables.length == 1) { double[] prob = new double[2]; double pred = regress(instance); prob[1] = MathUtils.sigmoid(pred); prob[0] = 1 - prob[1]; return prob; } else { double[] prob = new double[tables.length]; double[] pred = new double[tables.length]; for (int i = 0; i < tables.length; i++) { 
pred[i] = tables[i].regress(instance); } double max = StatUtils.max(pred); double sum = 0; for (int i = 0; i < prob.length; i++) { prob[i] = Math.exp(pred[i] - max); sum += prob[i]; } VectorUtils.divide(prob, sum); return prob; } } @Override public BDT copy() { BDT copy = new BDT(tables.length); for (int i = 0; i < tables.length; i++) { copy.tables[i] = tables[i].copy(); } return copy; } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/brt/BRT.java ================================================ package mltk.predictor.tree.ensemble.brt; import java.io.BufferedReader; import java.io.PrintWriter; import mltk.core.Instance; import mltk.predictor.ProbabilisticClassifier; import mltk.predictor.Regressor; import mltk.predictor.tree.RTree; import mltk.predictor.tree.ensemble.BoostedRTrees; import mltk.util.MathUtils; import mltk.util.StatUtils; import mltk.util.VectorUtils; /** * Class for boosted regression trees (BRTs). * * @author Yin Lou * */ public class BRT implements ProbabilisticClassifier, Regressor { protected BoostedRTrees[] trees; /** * Constructor. */ public BRT() { } /** * Constructor. * * @param k the number of classes. */ public BRT(int k) { trees = new BoostedRTrees[k]; for (int i = 0; i < trees.length; i++) { trees[i] = new BoostedRTrees(); } } /** * Returns the tree list for class k. * * @param k the class k. * @return the tree list for class k. 
*/ public BoostedRTrees getRegressionTreeList(int k) { return trees[k]; } @Override public int classify(Instance instance) { double[] prob = predictProbabilities(instance); return StatUtils.indexOfMax(prob); } @Override public void read(BufferedReader in) throws Exception { int k = Integer.parseInt(in.readLine().split(": ")[1]); trees = new BoostedRTrees[k]; for (int i = 0; i < trees.length; i++) { in.readLine(); trees[i] = new BoostedRTrees(); trees[i].read(in); in.readLine(); } } @Override public void write(PrintWriter out) throws Exception { out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName()); out.println("K: " + trees.length); for (BoostedRTrees rtList : trees) { rtList.write(out); out.println(); } } @Override public double regress(Instance instance) { return trees[0].regress(instance); } @Override public double[] predictProbabilities(Instance instance) { if (trees.length == 1) { double[] prob = new double[2]; double pred = regress(instance); prob[1] = MathUtils.sigmoid(pred); prob[0] = 1 - prob[1]; return prob; } else { double[] prob = new double[trees.length]; double[] pred = new double[trees.length]; for (int i = 0; i < trees.length; i++) { pred[i] = trees[i].regress(instance); } double max = StatUtils.max(pred); double sum = 0; for (int i = 0; i < prob.length; i++) { prob[i] = Math.exp(pred[i] - max); sum += prob[i]; } VectorUtils.divide(prob, sum); return prob; } } @Override public BRT copy() { BRT copy = new BRT(trees.length); for (int i = 0; i < trees.length; i++) { BoostedRTrees brts = trees[i]; for (RTree rt : brts) { copy.trees[i].add(rt.copy()); } } return copy; } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/brt/BRTLearner.java ================================================ package mltk.predictor.tree.ensemble.brt; import mltk.predictor.tree.RegressionTreeLearner.Mode; import mltk.predictor.tree.ensemble.TreeEnsembleLearner; /** * Abstract class for boosted regression tree 
learner. * * @author Yin Lou * */ public abstract class BRTLearner extends TreeEnsembleLearner { protected int maxNumIters; protected double alpha; protected double learningRate; /** * Constructor. */ public BRTLearner() { verbose = false; maxNumIters = 3500; learningRate = 0.01; alpha = 1; RobustRegressionTreeLearner rtLearner = new RobustRegressionTreeLearner(); rtLearner.setConstructionMode(Mode.NUM_LEAVES_LIMITED); rtLearner.setMaxNumLeaves(100); treeLearner = rtLearner; } /** * Returns the alpha. * * @return the alpha. */ public double getAlpha() { return alpha; } /** * Sets the alpha. * * @param alpha the parameter that controls the portion of the features to consider * for each boosting iteration. */ public void setAlpha(double alpha) { this.alpha = alpha; } /** * Returns the learning rate. * * @return the learning rate. */ public double getLearningRate() { return learningRate; } /** * Sets the learning rate. * * @param learningRate the learning rate. */ public void setLearningRate(double learningRate) { this.learningRate = learningRate; } /** * Returns the maximum number of iterations. * * @return the maximum number of iterations. */ public int getMaxNumIters() { return maxNumIters; } /** * Sets the maximum number of iterations. * * @param maxNumIters the maximum number of iterations. 
*/ public void setMaxNumIters(int maxNumIters) { this.maxNumIters = maxNumIters; } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/brt/BRTUtils.java ================================================ package mltk.predictor.tree.ensemble.brt; import mltk.predictor.tree.TreeLearner; import mltk.predictor.tree.DecisionTableLearner; import mltk.predictor.tree.RegressionTreeLearner; class BRTUtils { public static TreeLearner parseTreeLearner(String baseLearner) { String[] data = baseLearner.split(":"); if (data.length != 3) { throw new IllegalArgumentException(); } TreeLearner rtLearner = null; switch(data[0]) { case "rt": rtLearner = new RegressionTreeLearner(); break; case "rrt": rtLearner = new RobustRegressionTreeLearner(); break; case "dt": rtLearner = new DecisionTableLearner(); break; case "rdt": rtLearner = new RobustDecisionTableLearner(); break; default: System.err.println("Unknown regression tree learner: " + data[0]); throw new IllegalArgumentException(); } rtLearner.setParameters(data[1] + ":" + data[2]); return rtLearner; } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/brt/LADBoostLearner.java ================================================ package mltk.predictor.tree.ensemble.brt; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.cmdline.options.HoldoutValidatedLearnerOptions; import mltk.core.Attribute; import mltk.core.Instances; import mltk.core.io.InstancesReader; import mltk.predictor.evaluation.ConvergenceTester; import mltk.predictor.evaluation.MAE; import mltk.predictor.evaluation.Metric; import mltk.predictor.evaluation.MetricFactory; import mltk.predictor.evaluation.SimpleMetric; import mltk.predictor.io.PredictorWriter; import mltk.predictor.tree.RegressionTree; import 
mltk.predictor.tree.RegressionTreeLeaf; import mltk.predictor.tree.TreeLearner; import mltk.util.ArrayUtils; import mltk.util.MathUtils; import mltk.util.Permutation; import mltk.util.Random; /** * Class for least-absolute-deviation boost learner. * * @author Yin Lou * */ public class LADBoostLearner extends BRTLearner { static class Options extends HoldoutValidatedLearnerOptions { @Argument(name = "-b", description = "base learner (tree:mode:parameter) (default: rt:l:100)") String baseLearner = "rt:l:100"; @Argument(name = "-m", description = "maximum number of iterations", required = true) int maxNumIters = -1; @Argument(name = "-s", description = "seed of the random number generator (default: 0)") long seed = 0L; @Argument(name = "-l", description = "learning rate (default: 0.01)") double learningRate = 0.01; } /** * Trains a boosted tree ensemble using least-absolute-deviation as the objective function. * *
	 * Usage: mltk.predictor.tree.ensemble.brt.LADBoostLearner
	 * -t	train set path
	 * -m	maximum number of iterations
	 * [-v]	valid set path
	 * [-e]	evaluation metric (default: default metric of task)
	 * [-S]	convergence criteria (default: -1)
	 * [-r]	attribute file path
	 * [-o]	output model path
	 * [-V]	verbose (default: true)
	 * [-b]	base learner (tree:mode:parameter) (default: rt:l:100)
	 * [-s]	seed of the random number generator (default: 0)
	 * [-l]	learning rate (default: 0.01)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(LADBoostLearner.class, opts); Metric metric = null; TreeLearner rtLearner = null; try { parser.parse(args); if (opts.metric == null) { metric = new MAE(); } else { metric = MetricFactory.getMetric(opts.metric); } rtLearner = BRTUtils.parseTreeLearner(opts.baseLearner); } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Random.getInstance().setSeed(opts.seed); ConvergenceTester ct = ConvergenceTester.parse(opts.cc); Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath); LADBoostLearner learner = new LADBoostLearner(); learner.setLearningRate(opts.learningRate); learner.setMaxNumIters(opts.maxNumIters); learner.setVerbose(opts.verbose); learner.setMetric(metric); learner.setTreeLearner(rtLearner); learner.setConvergenceTester(ct); if (opts.validPath != null) { Instances validSet = InstancesReader.read(opts.attPath, opts.validPath); learner.setValidSet(validSet); } long start = System.currentTimeMillis(); BRT brt = learner.build(trainSet); long end = System.currentTimeMillis(); System.out.println("Time: " + (end - start) / 1000.0); if (opts.outputModelPath != null) { PredictorWriter.write(brt, opts.outputModelPath); } } /** * Constructor. */ public LADBoostLearner() { } @Override public BRT build(Instances instances) { if (metric == null) { metric = new MAE(); } if (validSet != null) { return buildRegressor(instances, validSet, maxNumIters); } else { return buildRegressor(instances, maxNumIters); } } /** * Builds a regressor. * * @param trainSet the training set. * @param validSet the validation set. * @param maxNumIters the maximum number of iterations. * @return a regressor. 
*/ public BRT buildRegressor(Instances trainSet, Instances validSet, int maxNumIters) { BRT brt = new BRT(1); treeLearner.cache(trainSet); List attributes = trainSet.getAttributes(); int limit = (int) (attributes.size() * alpha); int[] indices = new int[limit]; Permutation perm = new Permutation(attributes.size()); perm.permute(); // Backup targets double[] target = new double[trainSet.size()]; for (int i = 0; i < target.length; i++) { target[i] = trainSet.get(i).getTarget(); } double intercept = ArrayUtils.getMedian(target); RegressionTree initialTree = new RegressionTree(new RegressionTreeLeaf(intercept)); brt.trees[0].add(initialTree); double[] rTrain = new double[trainSet.size()]; for (int i = 0; i < rTrain.length; i++) { rTrain[i] = target[i] - intercept; } double[] pValid = new double[validSet.size()]; Arrays.fill(pValid, intercept); // Resets the convergence tester ct.setMetric(metric); for (int iter = 0; iter < maxNumIters; iter++) { // Prepare attributes if (alpha < 1) { int[] a = perm.getPermutation(); for (int i = 0; i < indices.length; i++) { indices[i] = a[i]; } Arrays.sort(indices); List attList = trainSet.getAttributes(indices); trainSet.setAttributes(attList); } // Prepare training set for (int i = 0; i < rTrain.length; i++) { trainSet.get(i).setTarget(MathUtils.sign(rTrain[i])); } RegressionTree rt = (RegressionTree) treeLearner.build(trainSet); brt.trees[0].add(rt); if (alpha < 1) { // Restore attributes trainSet.setAttributes(attributes); } // Replace the leaf value by median Map> map = new HashMap<>(); for (int i = 0; i < rTrain.length; i++) { RegressionTreeLeaf leaf = rt.getLeafNode(trainSet.get(i)); if (!map.containsKey(leaf)) { map.put(leaf, new ArrayList()); } map.get(leaf).add(i); } for (Map.Entry> entry : map.entrySet()) { RegressionTreeLeaf leaf = entry.getKey(); List list = entry.getValue(); double[] values = new double[list.size()]; for (int i = 0; i < values.length; i++) { values[i] = rTrain[list.get(i)]; } double pred = 
ArrayUtils.getMedian(values) * learningRate; leaf.setPrediction(pred); } // Update predictions and residuals for (int i = 0; i < rTrain.length; i++) { double pred = rt.regress(trainSet.get(i)); rTrain[i] -= pred; } for (int i = 0; i < pValid.length; i++) { double pred = rt.regress(validSet.get(i)); pValid[i] += pred; } double measure = metric.eval(pValid, validSet); ct.add(measure); if (verbose) { System.out.println("Iteration " + iter + ": " + measure); } if (ct.isConverged()) { break; } } // Search the best model on validation set int idx = ct.getBestIndex(); for (int i = brt.trees[0].size() - 1; i > idx; i--) { brt.trees[0].removeLast(); } // Restore targets for (int i = 0; i < target.length; i++) { trainSet.get(i).setTarget(target[i]); } treeLearner.evictCache(); return brt; } /** * Builds a regressor. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @return a regressor. */ public BRT buildRegressor(Instances trainSet, int maxNumIters) { BRT brt = new BRT(1); treeLearner.cache(trainSet); SimpleMetric simpleMetric = (SimpleMetric) metric; List attributes = trainSet.getAttributes(); int limit = (int) (attributes.size() * alpha); int[] indices = new int[limit]; Permutation perm = new Permutation(attributes.size()); perm.permute(); // Backup targets double[] target = new double[trainSet.size()]; for (int i = 0; i < target.length; i++) { target[i] = trainSet.get(i).getTarget(); } double intercept = ArrayUtils.getMedian(target); RegressionTree initialTree = new RegressionTree(new RegressionTreeLeaf(intercept)); brt.trees[0].add(initialTree); double[] pTrain = new double[trainSet.size()]; double[] rTrain = new double[trainSet.size()]; for (int i = 0; i < rTrain.length; i++) { pTrain[i] = intercept; rTrain[i] = target[i] - intercept; } for (int iter = 0; iter < maxNumIters; iter++) { // Prepare attributes if (alpha < 1) { int[] a = perm.getPermutation(); for (int i = 0; i < indices.length; i++) { indices[i] = a[i]; } 
Arrays.sort(indices); List attList = trainSet.getAttributes(indices); trainSet.setAttributes(attList); } // Prepare training set for (int i = 0; i < rTrain.length; i++) { trainSet.get(i).setTarget(MathUtils.sign(rTrain[i])); } RegressionTree rt = (RegressionTree) treeLearner.build(trainSet); brt.trees[0].add(rt); if (alpha < 1) { // Restore attributes trainSet.setAttributes(attributes); } // Replace the leaf value by median Map> map = new HashMap<>(); for (int i = 0; i < rTrain.length; i++) { RegressionTreeLeaf leaf = rt.getLeafNode(trainSet.get(i)); if (!map.containsKey(leaf)) { map.put(leaf, new ArrayList()); } map.get(leaf).add(i); } for (Map.Entry> entry : map.entrySet()) { RegressionTreeLeaf leaf = entry.getKey(); List list = entry.getValue(); double[] values = new double[list.size()]; for (int i = 0; i < values.length; i++) { values[i] = rTrain[list.get(i)]; } double pred = ArrayUtils.getMedian(values) * learningRate; leaf.setPrediction(pred); } // Update residuals for (int i = 0; i < rTrain.length; i++) { double pred = rt.regress(trainSet.get(i)); pTrain[i] += pred; rTrain[i] -= pred; } if (verbose) { double measure = simpleMetric.eval(pTrain, target); System.out.println("Iteration " + iter + ": " + measure); } } // Restore targets for (int i = 0; i < target.length; i++) { trainSet.get(i).setTarget(target[i]); } treeLearner.evictCache(); return brt; } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/brt/LSBoostLearner.java ================================================ package mltk.predictor.tree.ensemble.brt; import java.util.Arrays; import java.util.List; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.cmdline.options.HoldoutValidatedLearnerOptions; import mltk.core.Attribute; import mltk.core.Instances; import mltk.core.io.InstancesReader; import mltk.predictor.evaluation.ConvergenceTester; import mltk.predictor.evaluation.Metric; import 
mltk.predictor.evaluation.MetricFactory; import mltk.predictor.evaluation.RMSE; import mltk.predictor.evaluation.SimpleMetric; import mltk.predictor.io.PredictorWriter; import mltk.predictor.tree.RTree; import mltk.predictor.tree.TreeLearner; import mltk.util.Permutation; import mltk.util.Random; /** * Class for least-squares boost learner. * * @author Yin Lou * */ public class LSBoostLearner extends BRTLearner { static class Options extends HoldoutValidatedLearnerOptions { @Argument(name = "-b", description = "base learner (tree:mode:parameter) (default: rt:l:100)") String baseLearner = "rt:l:100"; @Argument(name = "-m", description = "maximum number of iterations", required = true) int maxNumIters = -1; @Argument(name = "-s", description = "seed of the random number generator (default: 0)") long seed = 0L; @Argument(name = "-l", description = "learning rate (default: 0.01)") double learningRate = 0.01; } /** * Trains a boosted tree ensemble using least square as the objective function. * *
	 * Usage: mltk.predictor.tree.ensemble.brt.LSBoostLearner
	 * -t	train set path
	 * -m	maximum number of iterations
	 * [-v]	valid set path
	 * [-e]	evaluation metric (default: default metric of task)
	 * [-S]	convergence criteria (default: -1) 
	 * [-r]	attribute file path
	 * [-o]	output model path
	 * [-V]	verbose (default: true)
	 * [-b]	base learner (tree:mode:parameter) (default: rt:l:100)
	 * [-s]	seed of the random number generator (default: 0)
	 * [-l]	learning rate (default: 0.01)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(LSBoostLearner.class, opts); Metric metric = null; TreeLearner rtLearner = null; try { parser.parse(args); if (opts.metric == null) { metric = new RMSE(); } else { metric = MetricFactory.getMetric(opts.metric); } rtLearner = BRTUtils.parseTreeLearner(opts.baseLearner); } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Random.getInstance().setSeed(opts.seed); ConvergenceTester ct = ConvergenceTester.parse(opts.cc); Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath); LSBoostLearner learner = new LSBoostLearner(); learner.setLearningRate(opts.learningRate); learner.setMaxNumIters(opts.maxNumIters); learner.setVerbose(opts.verbose); learner.setMetric(metric); learner.setTreeLearner(rtLearner); learner.setConvergenceTester(ct); if (opts.validPath != null) { Instances validSet = InstancesReader.read(opts.attPath, opts.validPath); learner.setValidSet(validSet); } long start = System.currentTimeMillis(); BRT brt = learner.build(trainSet); long end = System.currentTimeMillis(); System.out.println("Time: " + (end - start) / 1000.0); if (opts.outputModelPath != null) { PredictorWriter.write(brt, opts.outputModelPath); } } /** * Constructor. */ public LSBoostLearner() { } @Override public BRT build(Instances instances) { if (metric == null) { metric = new RMSE(); } if (validSet != null) { return buildRegressor(instances, validSet, maxNumIters); } else { return buildRegressor(instances, maxNumIters); } } /** * Builds a regressor. * * @param trainSet the training set. * @param validSet the validation set. * @param maxNumIters the maximum number of iterations. * @return a regressor. 
*/
	public BRT buildRegressor(Instances trainSet, Instances validSet, int maxNumIters) {
		BRT model = new BRT(1);
		treeLearner.cache(trainSet);

		List allAttributes = trainSet.getAttributes();
		int numSelected = (int) (allAttributes.size() * alpha);
		int[] selected = new int[numSelected];
		Permutation shuffle = new Permutation(allAttributes.size());
		shuffle.permute();

		// Backup targets: the loop below overwrites them with residuals
		double[] backup = new double[trainSet.size()];
		for (int i = 0; i < backup.length; i++) {
			backup[i] = trainSet.get(i).getTarget();
		}

		// With no trees yet, the residuals are the targets themselves
		double[] residuals = backup.clone();
		double[] validPreds = new double[validSet.size()];

		// Resets the convergence tester
		ct.setMetric(metric);

		for (int iter = 0; iter < maxNumIters; iter++) {
			// Restrict the training set to the sampled attributes
			if (alpha < 1) {
				int[] order = shuffle.getPermutation();
				System.arraycopy(order, 0, selected, 0, selected.length);
				Arrays.sort(selected);
				List subset = trainSet.getAttributes(selected);
				trainSet.setAttributes(subset);
			}

			// Fit the next tree to the current residuals
			for (int i = 0; i < residuals.length; i++) {
				trainSet.get(i).setTarget(residuals[i]);
			}

			RTree tree = (RTree) treeLearner.build(trainSet);
			if (learningRate != 1) {
				tree.multiply(learningRate);
			}
			model.trees[0].add(tree);

			if (alpha < 1) {
				// Restore attributes
				trainSet.setAttributes(allAttributes);
			}

			// Update residuals on the training set and predictions on the validation set
			for (int i = 0; i < residuals.length; i++) {
				residuals[i] -= tree.regress(trainSet.get(i));
			}
			for (int i = 0; i < validPreds.length; i++) {
				validPreds[i] += tree.regress(validSet.get(i));
			}

			double measure = metric.eval(validPreds, validSet);
			ct.add(measure);
			if (verbose) {
				System.out.println("Iteration " + iter + ": " + measure);
			}

			if (ct.isConverged()) {
				break;
			}
		}

		// Keep only the trees up to the best iteration on the validation set
		int bestIdx = ct.getBestIndex();
		int surplus = model.trees[0].size() - 1 - bestIdx;
		for (int r = 0; r < surplus; r++) {
			model.trees[0].removeLast();
		}

		// Restore targets
		for (int i = 0; i < backup.length; i++) {
			trainSet.get(i).setTarget(backup[i]);
		}

		treeLearner.evictCache();
		return model;
	}

	/**
	 * Builds a regressor without a validation set. All {@code maxNumIters} trees are kept; when verbose, the metric
	 * on the training set is printed after each iteration.
	 * 
	 * @param trainSet the training set.
	 * @param maxNumIters the maximum number of iterations.
	 * @return a regressor.
	 */
	public BRT buildRegressor(Instances trainSet, int maxNumIters) {
		BRT model = new BRT(1);
		treeLearner.cache(trainSet);

		SimpleMetric trainMetric = (SimpleMetric) metric;

		List allAttributes = trainSet.getAttributes();
		int numSelected = (int) (allAttributes.size() * alpha);
		int[] selected = new int[numSelected];
		Permutation shuffle = new Permutation(allAttributes.size());
		shuffle.permute();

		// Backup targets: the loop below overwrites them with residuals
		double[] backup = new double[trainSet.size()];
		for (int i = 0; i < backup.length; i++) {
			backup[i] = trainSet.get(i).getTarget();
		}

		double[] trainPreds = new double[trainSet.size()];
		// With no trees yet, the residuals are the targets themselves
		double[] residuals = backup.clone();

		for (int iter = 0; iter < maxNumIters; iter++) {
			// Restrict the training set to the sampled attributes
			if (alpha < 1) {
				int[] order = shuffle.getPermutation();
				System.arraycopy(order, 0, selected, 0, selected.length);
				Arrays.sort(selected);
				List subset = trainSet.getAttributes(selected);
				trainSet.setAttributes(subset);
			}

			// Fit the next tree to the current residuals
			for (int i = 0; i < residuals.length; i++) {
				trainSet.get(i).setTarget(residuals[i]);
			}

			RTree tree = (RTree) treeLearner.build(trainSet);
			if (learningRate != 1) {
				tree.multiply(learningRate);
			}
			model.trees[0].add(tree);

			if (alpha < 1) {
				// Restore attributes
				trainSet.setAttributes(allAttributes);
			}

			// Update predictions and residuals
			for (int i = 0; i < residuals.length; i++) {
				double pred = tree.regress(trainSet.get(i));
				trainPreds[i] += pred;
				residuals[i] -= pred;
			}

			if (verbose) {
				double measure = trainMetric.eval(trainPreds, backup);
				System.out.println("Iteration " + iter + ": " + measure);
			}
		}

		// Restore targets
		for (int i = 0; i < backup.length; i++) {
			trainSet.get(i).setTarget(backup[i]);
		}

		treeLearner.evictCache();
		return model;
	}

}

================================================
FILE: src/main/java/mltk/predictor/tree/ensemble/brt/LogitBoostLearner.java
================================================ package mltk.predictor.tree.ensemble.brt; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.cmdline.options.HoldoutValidatedLearnerOptions; import mltk.core.Attribute; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.NominalAttribute; import mltk.core.io.InstancesReader; import mltk.predictor.evaluation.ConvergenceTester; import mltk.predictor.evaluation.Error; import mltk.predictor.evaluation.Metric; import mltk.predictor.evaluation.MetricFactory; import mltk.predictor.evaluation.SimpleMetric; import mltk.predictor.io.PredictorWriter; import mltk.predictor.tree.RTree; import mltk.predictor.tree.TreeLearner; import mltk.util.MathUtils; import mltk.util.OptimUtils; import mltk.util.Permutation; import mltk.util.Random; /** * Class for logit boost learner. * *

* Reference:
* P. Li. Robust logitboost and adaptive base class (abc) logitboost. In Proceedings of the 26th Conference on
* Uncertainty in Artificial Intelligence (UAI), Catalina Island, CA, USA, 2010.
*

*
 * @author Yin Lou
 * 
 */
public class LogitBoostLearner extends BRTLearner {

	/**
	 * Command line options of {@link #main(String[])}, in addition to those inherited from
	 * {@code HoldoutValidatedLearnerOptions}.
	 */
	static class Options extends HoldoutValidatedLearnerOptions {

		// main() prefixes this value with "r" before parsing it
		@Argument(name = "-b", description = "base learner (tree:mode:parameter) (default: rt:l:100)")
		String baseLearner = "rt:l:100";

		// Required option; -1 is only a placeholder
		@Argument(name = "-m", description = "maximum number of iterations", required = true)
		int maxNumIters = -1;

		@Argument(name = "-s", description = "seed of the random number generator (default: 0)")
		long seed = 0L;

		@Argument(name = "-l", description = "learning rate (default: 0.01)")
		double learningRate = 0.01;

	}

	/**
	 * Trains an additive logistic regression.
	 * 
	 * 
	 * Usage: mltk.predictor.tree.ensemble.brt.LogitBoostLearner
	 * -t	train set path
	 * -m	maximum number of iterations
	 * [-v]	valid set path
	 * [-e]	evaluation metric (default: default metric of task)
	 * [-S]	convergence criteria (default: -1)
	 * [-r]	attribute file path
	 * [-o]	output model path
	 * [-V]	verbose (default: true)
	 * [-b]	base learner (tree:mode:parameter) (default: rt:l:100)
	 * [-s]	seed of the random number generator (default: 0)
	 * [-l]	learning rate (default: 0.01)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(LogitBoostLearner.class, opts); Metric metric = null; TreeLearner rtLearner = null; try { parser.parse(args); if (opts.metric == null) { metric = new Error(); } else { metric = MetricFactory.getMetric(opts.metric); } // Using robust version of the base tree learner opts.baseLearner = "r" + opts.baseLearner; rtLearner = BRTUtils.parseTreeLearner(opts.baseLearner); } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Random.getInstance().setSeed(opts.seed); ConvergenceTester ct = ConvergenceTester.parse(opts.cc); Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath); LogitBoostLearner learner = new LogitBoostLearner(); learner.setLearningRate(opts.learningRate); learner.setMaxNumIters(opts.maxNumIters); learner.setVerbose(opts.verbose); learner.setMetric(metric); learner.setTreeLearner(rtLearner); learner.setConvergenceTester(ct); if (opts.validPath != null) { Instances validSet = InstancesReader.read(opts.attPath, opts.validPath); learner.setValidSet(validSet); } long start = System.currentTimeMillis(); BRT brt = learner.build(trainSet); long end = System.currentTimeMillis(); System.out.println("Time: " + (end - start) / 1000.0 + " (s)."); if (opts.outputModelPath != null) { PredictorWriter.write(brt, opts.outputModelPath); } } /** * Constructor. */ public LogitBoostLearner() { } @Override public BRT build(Instances instances) { if (metric == null) { metric = new Error(); } if (validSet != null) { return buildClassifier(instances, validSet, maxNumIters); } else { return buildClassifier(instances, maxNumIters); } } /** * Builds a classifier. * * @param trainSet the training set. * @param validSet the validation set. * @param maxNumIters the maximum number of iterations. * @return a classifier. 
*/ public BRT buildBinaryClassifier(Instances trainSet, Instances validSet, int maxNumIters) { Attribute classAttribute = trainSet.getTargetAttribute(); if (classAttribute.getType() != Attribute.Type.NOMINAL) { throw new IllegalArgumentException("Class attribute must be nominal."); } NominalAttribute clazz = (NominalAttribute) classAttribute; if (clazz.getCardinality() != 2) { throw new IllegalArgumentException("Only binary classification is accepted."); } BRT brt = new BRT(1); treeLearner.cache(trainSet); List attributes = trainSet.getAttributes(); int limit = (int) (attributes.size() * alpha); int[] indices = new int[limit]; Permutation perm = new Permutation(attributes.size()); if (alpha < 1) { perm.permute(); } // Backup targets and weights double[] targetTrain = new double[trainSet.size()]; double[] weightTrain = new double[targetTrain.length]; for (int i = 0; i < targetTrain.length; i++) { Instance instance = trainSet.get(i); targetTrain[i] = instance.getTarget(); weightTrain[i] = instance.getWeight(); } // Initialization double[] predTrain = new double[targetTrain.length]; double[] probTrain = new double[targetTrain.length]; computeProbabilities(predTrain, probTrain); double[] rTrain = new double[targetTrain.length]; OptimUtils.computePseudoResidual(predTrain, targetTrain, rTrain); double[] predValid = new double[validSet.size()]; // Resets the convergence tester ct.setMetric(metric); for (int iter = 0; iter < maxNumIters; iter++) { // Prepare attributes if (alpha < 1) { int[] a = perm.getPermutation(); for (int i = 0; i < indices.length; i++) { indices[i] = a[i]; } Arrays.sort(indices); List attList = trainSet.getAttributes(indices); trainSet.setAttributes(attList); } // Prepare training set for (int i = 0; i < targetTrain.length; i++) { Instance instance = trainSet.get(i); double prob = probTrain[i]; double w = prob * (1 - prob); instance.setTarget(rTrain[i] * weightTrain[i]); instance.setWeight(w * weightTrain[i]); } RTree rt = (RTree) 
treeLearner.build(trainSet); if (learningRate != 1) { rt.multiply(learningRate); } brt.trees[0].add(rt); for (int i = 0; i < predTrain.length; i++) { double pred = rt.regress(trainSet.get(i)); predTrain[i] += pred; } for (int i = 0; i < predValid.length; i++) { double pred = rt.regress(validSet.get(i)); predValid[i] += pred; } if (alpha < 1) { // Restore attributes trainSet.setAttributes(attributes); } // Update residuals and probabilities OptimUtils.computePseudoResidual(predTrain, targetTrain, rTrain); computeProbabilities(predTrain, probTrain); double measure = metric.eval(predValid, validSet); ct.add(measure); if (verbose) { System.out.println("Iteration " + iter + ": " + measure); } if (ct.isConverged()) { break; } } // Search the best model on validation set int idx = ct.getBestIndex(); for (int i = brt.trees[0].size() - 1; i > idx; i--) { brt.trees[0].removeLast(); } // Restore targets and weights for (int i = 0; i < targetTrain.length; i++) { Instance instance = trainSet.get(i); instance.setTarget(targetTrain[i]); instance.setWeight(weightTrain[i]); } treeLearner.evictCache(); return brt; } /** * Builds a classifier. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @return a classifier. 
*/ public BRT buildBinaryClassifier(Instances trainSet, int maxNumIters) { Attribute classAttribute = trainSet.getTargetAttribute(); if (classAttribute.getType() != Attribute.Type.NOMINAL) { throw new IllegalArgumentException("Class attribute must be nominal."); } NominalAttribute clazz = (NominalAttribute) classAttribute; if (clazz.getCardinality() != 2) { throw new IllegalArgumentException("Only binary classification is accepted."); } SimpleMetric simpleMetric = (SimpleMetric) metric; BRT brt = new BRT(1); treeLearner.cache(trainSet); List attributes = trainSet.getAttributes(); int limit = (int) (attributes.size() * alpha); int[] indices = new int[limit]; Permutation perm = new Permutation(attributes.size()); if (alpha < 1) { perm.permute(); } // Backup targets and weights double[] targetTrain = new double[trainSet.size()]; double[] weightTrain = new double[targetTrain.length]; for (int i = 0; i < targetTrain.length; i++) { Instance instance = trainSet.get(i); targetTrain[i] = instance.getTarget(); weightTrain[i] = instance.getWeight(); } // Initialization double[] predTrain = new double[targetTrain.length]; double[] probTrain = new double[targetTrain.length]; computeProbabilities(predTrain, probTrain); double[] rTrain = new double[targetTrain.length]; OptimUtils.computePseudoResidual(predTrain, targetTrain, rTrain); List measureList = new ArrayList<>(maxNumIters); for (int iter = 0; iter < maxNumIters; iter++) { // Prepare attributes if (alpha < 1) { int[] a = perm.getPermutation(); for (int i = 0; i < indices.length; i++) { indices[i] = a[i]; } Arrays.sort(indices); List attList = trainSet.getAttributes(indices); trainSet.setAttributes(attList); } // Prepare training set for (int i = 0; i < targetTrain.length; i++) { Instance instance = trainSet.get(i); double prob = probTrain[i]; double w = prob * (1 - prob); instance.setTarget(rTrain[i] * weightTrain[i]); instance.setWeight(w * weightTrain[i]); } RTree rt = (RTree) treeLearner.build(trainSet); if 
(learningRate != 1) { rt.multiply(learningRate); } brt.trees[0].add(rt); for (int i = 0; i < predTrain.length; i++) { double pred = rt.regress(trainSet.get(i)); predTrain[i] += pred; } if (alpha < 1) { // Restore attributes trainSet.setAttributes(attributes); } // Update residuals and probabilities OptimUtils.computePseudoResidual(predTrain, targetTrain, rTrain); computeProbabilities(predTrain, probTrain); double measure = simpleMetric.eval(predTrain, targetTrain); measureList.add(measure); if (verbose) { System.out.println("Iteration " + iter + ": " + measure); } } // Search the best model on validation set int idx = metric.searchBestMetricValueIndex(measureList); for (int i = brt.trees[0].size() - 1; i > idx; i--) { brt.trees[0].removeLast(); } // Restore targets and weights for (int i = 0; i < targetTrain.length; i++) { Instance instance = trainSet.get(i); instance.setTarget(targetTrain[i]); instance.setWeight(weightTrain[i]); } treeLearner.evictCache(); return brt; } /** * Builds a classifier. * * @param trainSet the training set. * @param validSet the validation set. * @param maxNumIters the maximum number of iterations. * @return a classifier. 
*/ public BRT buildClassifier(Instances trainSet, Instances validSet, int maxNumIters) { Attribute classAttribute = trainSet.getTargetAttribute(); if (classAttribute.getType() != Attribute.Type.NOMINAL) { throw new IllegalArgumentException("Class attribute must be nominal."); } NominalAttribute clazz = (NominalAttribute) classAttribute; final int numClasses = clazz.getCardinality(); if (numClasses == 2) { return buildBinaryClassifier(trainSet, validSet, maxNumIters); } else { System.err.println("Multiclass LogitBoost, only use mis-classification rate as metric now"); final double l = learningRate * (numClasses - 1.0) / numClasses; BRT brt = new BRT(numClasses); treeLearner.cache(trainSet); List attributes = trainSet.getAttributes(); int limit = (int) (attributes.size() * alpha); int[] indices = new int[limit]; Permutation perm = new Permutation(attributes.size()); if (alpha < 1) { perm.permute(); } // Backup targets and weights double[] targetTrain = new double[trainSet.size()]; double[] weightTrain = new double[targetTrain.length]; for (int i = 0; i < targetTrain.length; i++) { Instance instance = trainSet.get(i); targetTrain[i] = instance.getTarget(); weightTrain[i] = instance.getWeight(); } double[] targetValid = new double[validSet.size()]; for (int i = 0; i < targetValid.length; i++) { targetValid[i] = validSet.get(i).getTarget(); } // Initialization double[][] predTrain = new double[numClasses][targetTrain.length]; double[][] probTrain = new double[numClasses][targetTrain.length]; int[][] rTrain = new int[numClasses][targetTrain.length]; for (int k = 0; k < numClasses; k++) { int[] rkTrain = rTrain[k]; double[] probkTrain = probTrain[k]; for (int i = 0; i < rkTrain.length; i++) { rkTrain[i] = MathUtils.indicator(targetTrain[i] == k); probkTrain[i] = 1.0 / numClasses; } } double[][] predValid = new double[numClasses][validSet.size()]; for (int iter = 0; iter < maxNumIters; iter++) { // Prepare attributes if (alpha < 1) { int[] a = perm.getPermutation(); for 
(int i = 0; i < indices.length; i++) { indices[i] = a[i]; } Arrays.sort(indices); List attList = trainSet.getAttributes(indices); trainSet.setAttributes(attList); } for (int k = 0; k < numClasses; k++) { // Prepare training set int[] rkTrain = rTrain[k]; double[] probkTrain = probTrain[k]; for (int i = 0; i < targetTrain.length; i++) { Instance instance = trainSet.get(i); double pk = probkTrain[i]; double t = rkTrain[i] - pk; double w = pk * (1 - pk); instance.setTarget(t * weightTrain[i]); instance.setWeight(w * weightTrain[i]); } RTree rt = (RTree) treeLearner.build(trainSet); rt.multiply(l); brt.trees[k].add(rt); double[] predkTrain = predTrain[k]; for (int i = 0; i < predkTrain.length; i++) { double p = rt.regress(trainSet.get(i)); predkTrain[i] += p; } double[] predkValid = predValid[k]; for (int i = 0; i < predkValid.length; i++) { double p = rt.regress(validSet.get(i)); predkValid[i] += p; } } if (alpha < 1) { // Restore attributes trainSet.setAttributes(attributes); } // Update probabilities computeProbabilities(predTrain, probTrain); if (verbose) { double error = 0; for (int i = 0; i < targetValid.length; i++) { double p = 0; double max = Double.NEGATIVE_INFINITY; for (int k = 0; k < numClasses; k++) { if (predValid[k][i] > max) { max = predValid[k][i]; p = k; } } if (p != targetValid[i]) { error++; } } error /= targetValid.length; System.out.println("Iteration " + iter + ": " + error); } } // Restore targets and weights for (int i = 0; i < targetTrain.length; i++) { Instance instance = trainSet.get(i); instance.setTarget(targetTrain[i]); instance.setWeight(weightTrain[i]); } treeLearner.evictCache(); return brt; } } /** * Builds a classifier. * * @param trainSet the training set. * @param maxNumIters the maximum number of iterations. * @return a classifier. 
*/
	public BRT buildClassifier(Instances trainSet, int maxNumIters) {
		Attribute classAttribute = trainSet.getTargetAttribute();
		if (classAttribute.getType() != Attribute.Type.NOMINAL) {
			throw new IllegalArgumentException("Class attribute must be nominal.");
		}
		NominalAttribute clazz = (NominalAttribute) classAttribute;
		final int numClasses = clazz.getCardinality();

		if (numClasses == 2) {
			return buildBinaryClassifier(trainSet, maxNumIters);
		}

		// Multiclass case: one tree per class per iteration
		final int n = trainSet.size();
		final double l = learningRate * (numClasses - 1.0) / numClasses;

		BRT brt = new BRT(numClasses);
		treeLearner.cache(trainSet);

		List<Attribute> attributes = trainSet.getAttributes();
		int limit = (int) (attributes.size() * alpha);
		int[] indices = new int[limit];
		Permutation perm = new Permutation(attributes.size());
		if (alpha < 1) {
			perm.permute();
		}

		// Backup targets and weights; both are overwritten on the instances during boosting
		double[] target = new double[n];
		double[] weight = new double[n];
		for (int i = 0; i < n; i++) {
			Instance instance = trainSet.get(i);
			target[i] = instance.getTarget();
			weight[i] = instance.getWeight();
		}

		// Initialization: class indicators and uniform class probabilities
		double[][] predTrain = new double[numClasses][n];
		double[][] probTrain = new double[numClasses][n];
		int[][] rTrain = new int[numClasses][n];
		for (int k = 0; k < numClasses; k++) {
			int[] rkTrain = rTrain[k];
			double[] probkTrain = probTrain[k];
			for (int i = 0; i < n; i++) {
				rkTrain[i] = MathUtils.indicator(target[i] == k);
				probkTrain[i] = 1.0 / numClasses;
			}
		}

		for (int iter = 0; iter < maxNumIters; iter++) {
			// Prepare attributes
			if (alpha < 1) {
				int[] a = perm.getPermutation();
				System.arraycopy(a, 0, indices, 0, indices.length);
				Arrays.sort(indices);
				List<Attribute> attList = trainSet.getAttributes(indices);
				trainSet.setAttributes(attList);
			}

			for (int k = 0; k < numClasses; k++) {
				// Prepare training set for class k
				int[] rkTrain = rTrain[k];
				double[] probkTrain = probTrain[k];
				for (int i = 0; i < n; i++) {
					Instance instance = trainSet.get(i);
					double pk = probkTrain[i];
					double t = rkTrain[i] - pk;
					double w = pk * (1 - pk);
					instance.setTarget(t * weight[i]);
					instance.setWeight(w * weight[i]);
				}

				RTree rt = (RTree) treeLearner.build(trainSet);
				rt.multiply(l);
				brt.trees[k].add(rt);

				double[] predkTrain = predTrain[k];
				for (int i = 0; i < n; i++) {
					double p = rt.regress(trainSet.get(i));
					predkTrain[i] += p;
				}
			}

			if (alpha < 1) {
				// Restore attributes
				trainSet.setAttributes(attributes);
			}

			// Update probabilities
			computeProbabilities(predTrain, probTrain);

			if (verbose) {
				// Misclassification rate on the training set: predict the most probable class
				double error = 0;
				for (int i = 0; i < n; i++) {
					double p = 0;
					double maxProb = -1;
					for (int k = 0; k < numClasses; k++) {
						if (probTrain[k][i] > maxProb) {
							maxProb = probTrain[k][i];
							p = k;
						}
					}
					if (p != target[i]) {
						error++;
					}
				}
				error /= n;
				System.out.println("Iteration " + iter + ": " + error);
			}
		}

		// Restore targets and weights
		for (int i = 0; i < n; i++) {
			Instance instance = trainSet.get(i);
			instance.setTarget(target[i]);
			instance.setWeight(weight[i]);
		}

		treeLearner.evictCache();
		return brt;
	}

	/**
	 * Sets the base tree learner.
	 * 
	 * @param treeLearner the tree learner; its {@code isRobust()} must return {@code true}.
	 * @throws IllegalArgumentException if the tree learner is not robust.
	 */
	@Override
	public void setTreeLearner(TreeLearner treeLearner) {
		if (!treeLearner.isRobust()) {
			throw new IllegalArgumentException("Only robust tree learners are accepted");
		}
		this.treeLearner = treeLearner;
	}

	/**
	 * Converts binary predictions to probabilities by applying the sigmoid element-wise.
	 * 
	 * @param pred the predictions.
	 * @param prob the output array, overwritten with the probabilities.
	 */
	protected void computeProbabilities(double[] pred, double[] prob) {
		for (int i = 0; i < pred.length; i++) {
			prob[i] = MathUtils.sigmoid(pred[i]);
		}
	}

	/**
	 * Converts per-class predictions to class probabilities with a softmax over the classes of each instance.
	 * 
	 * @param pred the predictions, indexed as {@code [class][instance]}.
	 * @param prob the output array of the same layout, overwritten with the probabilities.
	 */
	protected void computeProbabilities(double[][] pred, double[][] prob) {
		for (int i = 0; i < pred[0].length; i++) {
			// Shift by the largest score so that the largest exponent is exp(0)
			double max = Double.NEGATIVE_INFINITY;
			for (int k = 0; k < pred.length; k++) {
				if (max < pred[k][i]) {
					max = pred[k][i];
				}
			}
			double sum = 0;
			for (int k = 0; k < pred.length; k++) {
				double p = Math.exp(pred[k][i] - max);
				prob[k][i] = p;
				sum += p;
			}
			for (int k = 0; k < pred.length; k++) {
				prob[k][i] /= sum;
			}
		}
	}

}

================================================
FILE: src/main/java/mltk/predictor/tree/ensemble/brt/RobustDecisionTableLearner.java
================================================
package mltk.predictor.tree.ensemble.brt;

import java.util.Collections;
import java.util.List; import mltk.core.Instance; import mltk.core.Instances; import mltk.predictor.tree.DecisionTableLearner; import mltk.util.tuple.DoublePair; import mltk.util.tuple.IntDoublePair; /** * Class for learning decision tables in LogitBoost algorithm. * * @author Yin Lou * */ public class RobustDecisionTableLearner extends DecisionTableLearner { protected boolean getStats(Instances instances, double[] stats) { stats[0] = stats[1] = stats[2] = 0; if (instances.size() == 0) { return true; } double firstTarget = instances.get(0).getTarget(); boolean stdIs0 = true; for (Instance instance : instances) { double weight = instance.getWeight(); double target = instance.getTarget(); stats[0] += weight; // The key difference is we do not use weighted sum. stats[1] += target; if (stdIs0 && target != firstTarget) { stdIs0 = false; } } stats[2] = stats[1] / stats[0]; if (Double.isNaN(stats[2])) { stats[2] = 0; } return stdIs0; } protected void getHistogram(Instances instances, List pairs, List uniqueValues, double w, double s, List histogram) { if (pairs.size() > 0) { double lastValue = pairs.get(0).v2; double totalWeight = instances.get(pairs.get(0).v1).getWeight(); // The key difference is we do not use weighted sum. 
double sum = instances.get(pairs.get(0).v1).getTarget(); for (int i = 1; i < pairs.size(); i++) { IntDoublePair pair = pairs.get(i); double value = pair.v2; double weight = instances.get(pairs.get(i).v1).getWeight(); double resp = instances.get(pairs.get(i).v1).getTarget(); if (value != lastValue) { uniqueValues.add(lastValue); histogram.add(new DoublePair(totalWeight, sum)); lastValue = value; totalWeight = weight; sum = resp; } else { totalWeight += weight; sum += resp; } } uniqueValues.add(lastValue); histogram.add(new DoublePair(totalWeight, sum)); } if (pairs.size() != instances.size()) { // Zero entries are present double sumWeight = 0; double sumTarget = 0; for (DoublePair pair : histogram) { sumWeight += pair.v1; sumTarget += pair.v2; } double weightOnZero = w - sumWeight; double sumOnZero = s - sumTarget; int idx = Collections.binarySearch(uniqueValues, ZERO); if (idx < 0) { // This should always happen uniqueValues.add(-idx - 1, ZERO); histogram.add(-idx - 1, new DoublePair(weightOnZero, sumOnZero)); } } } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/brt/RobustRegressionTreeLearner.java ================================================ package mltk.predictor.tree.ensemble.brt; import java.util.Collections; import java.util.List; import mltk.core.Instance; import mltk.core.Instances; import mltk.predictor.tree.RegressionTreeLearner; import mltk.util.tuple.DoublePair; import mltk.util.tuple.IntDoublePair; /** * Class for learning regression trees in LogitBoost algorithm. The splitting criteria ignores the weights when * calculating sum of responses. 
* * @author Yin Lou * */ public class RobustRegressionTreeLearner extends RegressionTreeLearner { public boolean isRobust() { return true; } protected boolean getStats(Instances instances, double[] stats) { stats[0] = stats[1] = stats[2] = 0; if (instances.size() == 0) { return true; } double firstTarget = instances.get(0).getTarget(); boolean stdIs0 = true; for (Instance instance : instances) { double weight = instance.getWeight(); double target = instance.getTarget(); stats[0] += weight; // The key difference is we do not use weighted sum. stats[1] += target; if (stdIs0 && target != firstTarget) { stdIs0 = false; } } stats[2] = stats[1] / stats[0]; if (Double.isNaN(stats[2])) { stats[2] = 0; } return stdIs0; } protected void getHistogram(Instances instances, List pairs, List uniqueValues, double w, double s, List histogram) { if (pairs.size() > 0) { double lastValue = pairs.get(0).v2; double totalWeight = instances.get(pairs.get(0).v1).getWeight(); // The key difference is we do not use weighted sum. 
double sum = instances.get(pairs.get(0).v1).getTarget(); for (int i = 1; i < pairs.size(); i++) { IntDoublePair pair = pairs.get(i); double value = pair.v2; double weight = instances.get(pairs.get(i).v1).getWeight(); double resp = instances.get(pairs.get(i).v1).getTarget(); if (value != lastValue) { uniqueValues.add(lastValue); histogram.add(new DoublePair(totalWeight, sum)); lastValue = value; totalWeight = weight; sum = resp; } else { totalWeight += weight; sum += resp; } } uniqueValues.add(lastValue); histogram.add(new DoublePair(totalWeight, sum)); } if (pairs.size() != instances.size()) { // Zero entries are present double sumWeight = 0; double sumTarget = 0; for (DoublePair pair : histogram) { sumWeight += pair.v1; sumTarget += pair.v2; } double weightOnZero = w - sumWeight; double sumOnZero = s - sumTarget; int idx = Collections.binarySearch(uniqueValues, ZERO); if (idx < 0) { // This should always happen uniqueValues.add(-idx - 1, ZERO); histogram.add(-idx - 1, new DoublePair(weightOnZero, sumOnZero)); } } } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/brt/package-info.java ================================================ /** * Provides algorithms for fitting boosted regression trees (BRTs). */ package mltk.predictor.tree.ensemble.brt; ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/package-info.java ================================================ /** * Provides algorithms for tree ensemble methods. 
*/ package mltk.predictor.tree.ensemble; ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/rf/RandomForest.java ================================================ package mltk.predictor.tree.ensemble.rf; import java.io.BufferedReader; import java.io.PrintWriter; import mltk.core.Instance; import mltk.predictor.Regressor; import mltk.predictor.tree.RTree; import mltk.predictor.tree.RegressionTree; import mltk.predictor.tree.ensemble.RTreeList; /** * Class for random forests. * * @author Yin Lou * */ public class RandomForest implements Regressor { protected RTreeList rtList; /** * Constructor. */ public RandomForest() { rtList = new RTreeList(); } /** * Constructor. * * @param capacity the capacity of this random forest. */ public RandomForest(int capacity) { rtList = new RTreeList(capacity); } @Override public void read(BufferedReader in) throws Exception { int capacity = Integer.parseInt(in.readLine().split(": ")[1]); rtList = new RTreeList(capacity); in.readLine(); for (int i = 0; i < capacity; i++) { in.readLine(); RegressionTree rt = new RegressionTree(); rt.read(in); rtList.add(rt); in.readLine(); } } @Override public void write(PrintWriter out) throws Exception { out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName()); out.println("Ensemble: " + size()); out.println(); for (RTree rt : rtList) { rt.write(out); out.println(); } } @Override public RandomForest copy() { RandomForest copy = new RandomForest(size()); for (RTree rt : rtList) { copy.add(rt.copy()); } return copy; } @Override public double regress(Instance instance) { if (size() == 0) { return 0.0; } else { double prediction = 0.0; for (RTree rt : rtList) { prediction += rt.regress(instance); } return prediction / size(); } } /** * Adds a regression tree to the ensemble. * * @param rt the regression tree. */ public void add(RTree rt) { rtList.add(rt); } /** * Returns the tree at the specified position in this list. 
* * @param index the index of the element to return. * @return the tree at the specified position in this list. */ public RTree get(int index) { return rtList.get(index); } /** * Returns the list of regression trees. * * @return the list of regression trees. */ public RTreeList getTreeList() { return rtList; } /** * Returns the size of this random forest. * * @return the size of this random forest. */ public int size() { return rtList.size(); } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/rf/RandomForestLearner.java ================================================ package mltk.predictor.tree.ensemble.rf; import mltk.cmdline.Argument; import mltk.cmdline.CmdLineParser; import mltk.cmdline.options.LearnerOptions; import mltk.core.Instances; import mltk.core.Sampling; import mltk.core.io.InstancesReader; import mltk.predictor.Learner; import mltk.predictor.io.PredictorWriter; import mltk.predictor.tree.RegressionTreeLearner; import mltk.predictor.tree.RegressionTreeLearner.Mode; /** * Class for learning random forests. * * @author Yin Lou * */ public class RandomForestLearner extends Learner { static class Options extends LearnerOptions { @Argument(name = "-m", description = "construction mode:parameter. Construction mode can be alpha limited (a), depth limited (d), number of leaves limited (l) and minimum leaf size limited (s) (default: a:0.001)") String mode = "a:0.001"; @Argument(name = "-f", description = "number of features to consider (default: 1/3 of total number of features)") int numFeatures = -1; @Argument(name = "-b", description = "bagging iterations (default: 100)") int baggingIters = 100; } /** * Trains a random forest of regression trees. * * When bagging is turned off (b = 0), this procedure generates a single random regression tree. When the number of * features to consider is the number of total features, this procedure builds bagged tree. * *
	 * Usage: mltk.predictor.tree.ensemble.rf.RandomForestLearner
	 * -t	train set path
	 * [-r]	attribute file path
	 * [-o]	output model path
	 * [-V]	verbose (default: true)
	 * [-m]	construction mode:parameter. Construction mode can be alpha limited (a), depth limited (d), number of leaves limited (l) and minimum leaf size limited (s) (default: a:0.001)
	 * [-f]	number of features to consider
	 * [-b]	bagging iterations (default: 100)
	 * 
* * @param args the command line arguments. * @throws Exception */ public static void main(String[] args) throws Exception { Options opts = new Options(); CmdLineParser parser = new CmdLineParser(RandomForestLearner.class, opts); RandomRegressionTreeLearner rtLearner = new RandomRegressionTreeLearner(); try { parser.parse(args); String[] data = opts.mode.split(":"); if (data.length != 2) { throw new IllegalArgumentException(); } switch (data[0]) { case "a": rtLearner.setConstructionMode(Mode.ALPHA_LIMITED); rtLearner.setAlpha(Double.parseDouble(data[1])); break; case "d": rtLearner.setConstructionMode(Mode.DEPTH_LIMITED); rtLearner.setMaxDepth(Integer.parseInt(data[1])); break; case "l": rtLearner.setConstructionMode(Mode.NUM_LEAVES_LIMITED); rtLearner.setMaxNumLeaves(Integer.parseInt(data[1])); break; case "s": rtLearner.setConstructionMode(Mode.MIN_LEAF_SIZE_LIMITED); rtLearner.setMinLeafSize(Integer.parseInt(data[1])); default: throw new IllegalArgumentException(); } } catch (IllegalArgumentException e) { parser.printUsage(); System.exit(1); } Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath); rtLearner.setNumFeatures(opts.numFeatures); RandomForestLearner rfLearner = new RandomForestLearner(); rfLearner.setBaggingIterations(opts.baggingIters); rfLearner.setRegressionTreeLearner(rtLearner); rfLearner.setVerbose(opts.verbose); long start = System.currentTimeMillis(); RandomForest rf = rfLearner.build(trainSet); long end = System.currentTimeMillis(); System.out.println("Time: " + (end - start) / 1000.0 + " (s)."); if (opts.outputModelPath != null) { PredictorWriter.write(rf, opts.outputModelPath); } } private int baggingIters; private RegressionTreeLearner rtLearner; @Override public RandomForest build(Instances instances) { // Create bags Instances[] bags = Sampling.createBags(instances, baggingIters); RandomForest rf = new RandomForest(baggingIters); for (Instances bag : bags) { rf.add(rtLearner.build(bag)); } return rf; } /** * 
Constructor. */ public RandomForestLearner() { verbose = false; baggingIters = 100; rtLearner = new RandomRegressionTreeLearner(); } /** * Returns the number of bagging iterations. * * @return the number of bagging iterations. */ public int getBaggingIterations() { return baggingIters; } /** * Sets the number of bagging iterations. * * @param baggingIters the number of bagging iterations. */ public void setBaggingIterations(int baggingIters) { this.baggingIters = baggingIters; } /** * Returns the regression tree learner. * * @return the regression tree learner. */ public RegressionTreeLearner getRegressionTreeLearner() { return rtLearner; } /** * Sets the regression tree learner. * * @param rtLearner the regression tree learner. */ public void setRegressionTreeLearner(RegressionTreeLearner rtLearner) { this.rtLearner = rtLearner; } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/rf/RandomRegressionTreeLearner.java ================================================ package mltk.predictor.tree.ensemble.rf; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; import mltk.core.Attribute; import mltk.core.Instances; import mltk.predictor.tree.RegressionTree; import mltk.predictor.tree.RegressionTreeLeaf; import mltk.predictor.tree.RegressionTreeLearner; import mltk.predictor.tree.TreeInteriorNode; import mltk.predictor.tree.TreeNode; import mltk.util.Permutation; import mltk.util.Random; import mltk.util.tuple.DoublePair; import mltk.util.tuple.IntDoublePair; /** * Class for learning random regression trees. * * @author Yin Lou * */ public class RandomRegressionTreeLearner extends RegressionTreeLearner { protected int numFeatures; protected Permutation perm; /** * Constructor. */ public RandomRegressionTreeLearner() { numFeatures = -1; alpha = 0.01; mode = Mode.ALPHA_LIMITED; } /** * Returns the maximum number of features to consider for each node. 
* * @return the maximum number of features to consider for each node. */ public int getNumFeatures() { return numFeatures; } /** * Sets the maximum number of features to consider for each node. * * @param numFeatures the new maximum number of features. */ public void setNumFeatures(int numFeatures) { this.numFeatures = numFeatures; } @Override public RegressionTree build(Instances instances) { if (numFeatures < 0) { numFeatures = instances.getAttributes().size() / 3; } if (perm == null || perm.size() != instances.getAttributes().size()) { perm = new Permutation(instances.getAttributes().size()); } RegressionTree rt = null; switch (mode) { case ALPHA_LIMITED: rt = buildAlphaLimitedTree(instances, alpha); break; case NUM_LEAVES_LIMITED: rt = buildNumLeafLimitedTree(instances, maxNumLeaves); break; case DEPTH_LIMITED: rt = buildDepthLimitedTree(instances, maxDepth); break; case MIN_LEAF_SIZE_LIMITED: rt = buildMinLeafSizeLimitedTree(instances, minLeafSize); default: break; } return rt; } protected TreeNode createNode(Dataset dataset, int limit, double[] stats) { boolean stdIs0 = getStats(dataset.instances, stats); final double totalWeights = stats[0]; final double sum = stats[1]; final double weightedMean = stats[2]; // 1. Check basic leaf conditions if (stats[0] < limit || stdIs0) { RegressionTreeLeaf node = new RegressionTreeLeaf(weightedMean); return node; } // 2. 
Find best split double bestEval = Double.POSITIVE_INFINITY; List splits = new ArrayList<>(); List attributes = dataset.instances.getAttributes(); int[] a = perm.permute().getPermutation(); Set selected = new HashSet<>(numFeatures); for (int i = 0; i < numFeatures; i++) { selected.add(a[i]); } for (int j = 0; j < attributes.size(); j++) { int attIndex = attributes.get(j).getIndex(); if (!selected.contains(j)) { continue; } String attName = attributes.get(j).getName(); List sortedList = dataset.sortedLists.get(attName); List uniqueValues = new ArrayList<>(sortedList.size()); List histogram = new ArrayList<>(sortedList.size()); getHistogram(dataset.instances, sortedList, uniqueValues, totalWeights, sum, histogram); if (uniqueValues.size() > 1) { DoublePair split = split(uniqueValues, histogram, totalWeights, sum); if (split.v2 <= bestEval) { IntDoublePair splitPoint = new IntDoublePair(attIndex, split.v1); if (split.v2 < bestEval) { splits.clear(); bestEval = split.v2; } splits.add(splitPoint); } } } if (bestEval < Double.POSITIVE_INFINITY) { Random rand = Random.getInstance(); IntDoublePair splitPoint = splits.get(rand.nextInt(splits.size())); int attIndex = splitPoint.v1; TreeNode node = new TreeInteriorNode(attIndex, splitPoint.v2); stats[3] = bestEval + totalWeights * weightedMean * weightedMean; return node; } else { RegressionTreeLeaf node = new RegressionTreeLeaf(weightedMean); return node; } } } ================================================ FILE: src/main/java/mltk/predictor/tree/ensemble/rf/package-info.java ================================================ /** * Provides algorithms for fitting random forests (RFs). */ package mltk.predictor.tree.ensemble.rf; ================================================ FILE: src/main/java/mltk/predictor/tree/package-info.java ================================================ /** * Provides algorithms for tree-based methods. 
*/ package mltk.predictor.tree; ================================================ FILE: src/main/java/mltk/util/ArrayUtils.java ================================================ package mltk.util; import java.util.Arrays; import java.util.List; /** * Class for utility functions for arrays. * * @author Yin Lou * */ public class ArrayUtils { /** * Converts an integer list to int array. * * @param list the list. * @return an int array. */ public static int[] toIntArray(List list) { int[] a = new int[list.size()]; for (int i = 0; i < list.size(); i++) { a[i] = list.get(i); } return a; } /** * Converts a double list to double array. * * @param list the list. * @return a double array. */ public static double[] toDoubleArray(List list) { double[] a = new double[list.size()]; for (int i = 0; i < list.size(); i++) { a[i] = list.get(i); } return a; } /** * Converts a double array to an int array. * * @param a the double array. * @return an int array. */ public static int[] toIntArray(double[] a) { int[] b = new int[a.length]; for (int i = 0; i < a.length; i++) { b[i] = (int) a[i]; } return b; } /** * Returns a string representation of the contents of the specified sub-array. * * @param a the array. * @param start the starting index (inclusive). * @param end the ending index (exclusive). * @return Returns a string representation of the contents of the specified sub-array. */ public static String toString(double[] a, int start, int end) { StringBuilder sb = new StringBuilder(); sb.append("[").append(a[start]); for (int i = start + 1; i < end; i++) { sb.append(", ").append(a[i]); } sb.append("]"); return sb.toString(); } /** * Parses a double array from a string (default delimiter: ","). * * @param str the string representation of a double array. * @return a double array. */ public static double[] parseDoubleArray(String str) { return parseDoubleArray(str, ","); } /** * Parses a double array from a string. * * @param str the string representation of a double array. 
* @param delimiter the delimiter. * @return a double array. */ public static double[] parseDoubleArray(String str, String delimiter) { if (str == null || str.equalsIgnoreCase("null")) { return null; } String[] data = str.substring(1, str.length() - 1).split(delimiter); double[] a = new double[data.length]; for (int i = 0; i < a.length; i++) { a[i] = Double.parseDouble(data[i].trim()); } return a; } /** * Parses an int array from a string (default delimiter: ","). * * @param str the string representation of an int array. * @return an int array. */ public static int[] parseIntArray(String str) { return parseIntArray(str, ","); } /** * Parses an int array from a string. * * @param str the string representation of an int array. * @param delimiter the delimiter. * @return an int array. */ public static int[] parseIntArray(String str, String delimiter) { if (str == null || str.equalsIgnoreCase("null")) { return null; } String[] data = str.substring(1, str.length() - 1).split(delimiter); int[] a = new int[data.length]; for (int i = 0; i < a.length; i++) { a[i] = Integer.parseInt(data[i].trim()); } return a; } /** * Parses a long array from a string (default delimiter: ","). * * @param str the string representation of a long array. * @return an long array. */ public static long[] parseLongArray(String str) { return parseLongArray(str, ","); } /** * Parses a long array from a string. * * @param str the string representation of a long array. * @param delimiter the delimiter. * @return a long array. */ public static long[] parseLongArray(String str, String delimiter) { if (str == null || str.equalsIgnoreCase("null")) { return null; } String[] data = str.substring(1, str.length() - 1).split(delimiter); long[] a = new long[data.length]; for (int i = 0; i < a.length; i++) { a[i] = Long.parseLong(data[i].trim()); } return a; } /** * Returns {@code true} if the specified range of an array is constant c. * * @param a the array. * @param begin the index of first element (inclusive). 
* @param end the index of last element (exclusive). * @param c the constant to test. * @return {@code true} if the specified range of an array is constant c. */ public static boolean isConstant(double[] a, int begin, int end, double c) { for (int i = begin; i < end; i++) { if (!MathUtils.equals(a[i], c)) { return false; } } return true; } /** * Returns {@code true} if the specified range of an array is constant c. * * @param a the array. * @param begin the index of first element (inclusive). * @param end the index of last element (exclusive). * @param c the constant to test. * @return {@code true} if the specified range of an array is constant c. */ public static boolean isConstant(int[] a, int begin, int end, int c) { for (int i = begin; i < end; i++) { if (a[i] != c) { return false; } } return true; } /** * Returns {@code true} if the specified range of an array is constant c. * * @param a the array. * @param begin the index of first element (inclusive). * @param end the index of last element (exclusive). * @param c the constant to test. * @return {@code true} if the specified range of an array is constant c. */ public static boolean isConstant(byte[] a, int begin, int end, byte c) { for (int i = begin; i < end; i++) { if (a[i] != c) { return false; } } return true; } /** * Returns the median of an array. * * @param a the array. * @return the median of an array. */ public static double getMedian(double[] a) { if (a.length == 0) { return 0; } double[] ary = Arrays.copyOf(a, a.length); Arrays.sort(ary); int mid = ary.length / 2; if (ary.length % 2 == 1) { return ary[mid]; } else { return (ary[mid - 1] + ary[mid]) / 2; } } /** * Returns the index of the search key if it is contained in the array, otherwise returns the insertion point. * * @param a the array. * @param key the search key. * @return the index of the search key if it is contained in the array, otherwise returns the insertion point. 
*/ public static int findInsertionPoint(double[] a, double key) { int idx = Arrays.binarySearch(a, key); if (idx < 0) { idx = -idx - 1; } return idx; } } ================================================ FILE: src/main/java/mltk/util/Element.java ================================================ package mltk.util; /** * Class for weighted elements. * * @author Yin Lou * * @param the type of the element. */ public class Element implements Comparable> { public T element; public double weight; /** * Constructs a weighted element. * * @param element the element. * @param weight the weight. */ public Element(T element, double weight) { this.element = element; this.weight = weight; } @Override public int compareTo(Element e) { return Double.compare(this.weight, e.weight); } } ================================================ FILE: src/main/java/mltk/util/MathUtils.java ================================================ package mltk.util; /** * Class for utility functions for math. * * @author Yin Lou * */ public class MathUtils { /** * 1e-8 */ public static final double EPSILON = 1e-8; /** * log(2) */ public static final double LOG2 = Math.log(2); /** * Returns {@code true} if two doubles are equal to within {@link mltk.util.MathUtils#EPSILON}. * * @param a the 1st number. * @param b the 2nd number. * @return {@code true} if two doubles are equal to within {@link mltk.util.MathUtils#EPSILON}. */ public static boolean equals(double a, double b) { return Math.abs(a - b) < EPSILON; } /** * Returns 1 if the input is true and 0 otherwise. * * @param b the input. * @return 1 if the input is true and 0 otherwise. */ public static int indicator(boolean b) { return b ? 1 : 0; } /** * Returns {@code true} if the first value is better. * * @param a the 1st value. * @param b the 2nd value. * @param isLargerBetter {@code true} if the first value is better. * @return {@code true} if the first value is better. 
/**
 * Class for utility functions for math.
 *
 * @author Yin Lou
 *
 */
public class MathUtils {

	/**
	 * 1e-8
	 */
	public static final double EPSILON = 1e-8;

	/**
	 * log(2)
	 */
	public static final double LOG2 = Math.log(2);

	/**
	 * Tests whether two doubles differ by strictly less than {@link mltk.util.MathUtils#EPSILON}.
	 *
	 * @param a the 1st number.
	 * @param b the 2nd number.
	 * @return {@code true} if the absolute difference is below {@link mltk.util.MathUtils#EPSILON}.
	 */
	public static boolean equals(double a, double b) {
		final double gap = Math.abs(a - b);
		return gap < EPSILON;
	}

	/**
	 * Maps a boolean to 0/1.
	 *
	 * @param b the input.
	 * @return 1 if the input is true and 0 otherwise.
	 */
	public static int indicator(boolean b) {
		if (b) {
			return 1;
		}
		return 0;
	}

	/**
	 * Tests whether the first value is strictly better than the second.
	 *
	 * @param a the 1st value.
	 * @param b the 2nd value.
	 * @param isLargerBetter {@code true} if larger values are better.
	 * @return {@code true} if the first value is better.
	 */
	public static boolean isFirstBetter(double a, double b, boolean isLargerBetter) {
		return isLargerBetter ? a > b : a < b;
	}

	/**
	 * Tests whether a floating number has no fractional part.
	 *
	 * @param v the floating number.
	 * @return {@code true} if the floating number is integer.
	 */
	public static boolean isInteger(double v) {
		final double fraction = v % 1;
		return fraction == 0;
	}

	/**
	 * Tests whether the magnitude of a floating number is below {@link mltk.util.MathUtils#EPSILON}.
	 *
	 * @param v the floating number.
	 * @return {@code true} if the floating number is zero.
	 */
	public static boolean isZero(double v) {
		final double magnitude = Math.abs(v);
		return magnitude < EPSILON;
	}

	/**
	 * Evaluates the sigmoid function 1 / (1 + exp(-a)).
	 *
	 * @param a the number.
	 * @return the value of a sigmoid function.
	 */
	public static double sigmoid(double a) {
		final double denominator = 1 + Math.exp(-a);
		return 1 / denominator;
	}

	/**
	 * Returns -1, 0 or 1 according to the sign of a number.
	 *
	 * @param a the number.
	 * @return the sign of a number.
	 */
	public static int sign(double a) {
		return a < 0 ? -1 : (a > 0 ? 1 : 0);
	}

	/**
	 * Returns -1, 0 or 1 according to the sign of a number.
	 *
	 * @param a the number.
	 * @return the sign of a number.
	 */
	public static int sign(int a) {
		return Integer.signum(a);
	}

	/**
	 * Divides two numbers, substituting a default when the denominator is zero according to
	 * {@link #isZero(double)}.
	 *
	 * @param a the numerator.
	 * @param b the denominator.
	 * @param dv the default value.
	 * @return a / b or default value when division by zero.
	 */
	public static double divide(double a, double b, double dv) {
		if (isZero(b)) {
			return dv;
		}
		return a / b;
	}

}
* @return the gain for variance reduction. */ public static double getGain(double sum, double weight) { if (weight < MathUtils.EPSILON) { return 0; } else { return sum / weight * sum; } } /** * Returns the probability of being in positive class. * * @param pred the prediction. * @return the probability of being in positive class. */ public static double getProbability(double pred) { return 1.0 / (1.0 + Math.exp(-pred)); } /** * Returns the residual. * * @param pred the prediction. * @param target the target. * @return the residual. */ public static double getResidual(double pred, double target) { return target - pred; } /** * Returns the pseudo residual. * * @param pred the prediction. * @param cls the class label. * @return the pseudo residual. */ public static double getPseudoResidual(double pred, double cls) { return cls - getProbability(pred); } /** * Computes the pseudo residuals. * * @param prediction the prediction array. * @param y the class label array. * @param residual the residual array. */ public static void computePseudoResidual(double[] prediction, double[] y, double[] residual) { for (int i = 0; i < residual.length; i++) { residual[i] = getPseudoResidual(prediction[i], y[i]); } } /** * Computes the probabilities. * * @param pred the prediction array. * @param prob the probability array. */ public static void computeProbabilities(double[] pred, double[] prob) { for (int i = 0; i < pred.length; i++) { prob[i] = getProbability(pred[i]); } } /** * Computes the logistic loss for binary classification problems. * * @param pred the prediction. * @param cls the class label. * @return the logistic loss for binary classification problems. */ public static double computeLogisticLoss(double pred, double cls) { if (cls == 1) { return Math.log(1 + Math.exp(-pred)); } else { return Math.log(1 + Math.exp(pred)); } } /** * Computes the logistic loss for binary classification problems. * * @param pred the prediction array. * @param y the class label array. 
* @return the logistic loss for binary classification problems. */ public static double computeLogisticLoss(double[] pred, double[] y) { double loss = 0; for (int i = 0; i < pred.length; i++) { loss += computeLogisticLoss(pred[i], y[i]); } return loss / y.length; } /** * Computes the log loss (cross entropy) for binary classification problems. * * @param prob the probability. * @param y the class label. * @return the log loss. */ public static double computeLogLoss(double prob, double y) { return computeLogLoss(prob, y, false); } /** * Computes the log loss (cross entropy) for binary classification problems. * * @param p the input. * @param y the class label. * @param isRawScore {@code true} if the input is raw score. * @return the log loss. */ public static double computeLogLoss(double p, double y, boolean isRawScore) { if (isRawScore) { p = MathUtils.sigmoid(p); } if (y == 1) { return -Math.log(p); } else { return -Math.log(1 - p); } } /** * Computes the log loss (cross entropy) for binary classification problems. * * @param prob the probabilities. * @param y the class label array. * @return the log loss. */ public static double computeLogLoss(double[] prob, double[] y) { return computeLogLoss(prob, y, false); } /** * Computes the log loss (cross entropy) for binary classification problems. * * @param p the input. * @param y the targets * @param isRawScore {@code true} if the input is raw score. * @return the log loss. */ public static double computeLogLoss(double[] p, double[] y, boolean isRawScore) { double logLoss = 0; for (int i = 0; i < p.length; i++) { logLoss += computeLogLoss(p[i], y[i], isRawScore); } return logLoss; } /** * Computes the quadratic loss for regression problems. * * @param residual the residual array. * @return the quadratic loss for regression problems. 
*/ public static double computeQuadraticLoss(double[] residual) { return StatUtils.sumSq(residual) / (2 * residual.length); } /** * Returns gradient on the intercept in regression problems. Residuals will be updated accordingly. * * @param residual the residual array. * @return the fitted intercept. */ public static double fitIntercept(double[] residual) { double delta = StatUtils.mean(residual); VectorUtils.subtract(residual, delta); return delta; } /** * Returns gradient on the intercept in binary classification problems. Predictions and residuals will be updated accordingly. * * @param prediction the prediction array. * @param residual the residual array. * @param y the class label array. * @return the fitted intercept. */ public static double fitIntercept(double[] prediction, double[] residual, double[] y) { double delta = 0; // Use Newton-Raphson's method to approximate // 1st derivative double eta = 0; // 2nd derivative double theta = 0; for (int i = 0; i < prediction.length; i++) { double r = residual[i]; double t = Math.abs(r); eta += r; theta += t * (1 - t); } if (Math.abs(theta) > MathUtils.EPSILON) { delta = eta / theta; // Update predictions VectorUtils.add(prediction, delta); computePseudoResidual(prediction, y, residual); } return delta; } /** * Returns {@code true} if the relative improvement is less than a threshold. * * @param prevLoss the previous loss. * @param currLoss the current loss. * @param epsilon the threshold. * @return {@code true} if the relative improvement is less than a threshold. */ public static boolean isConverged(double prevLoss, double currLoss, double epsilon) { if (prevLoss < MathUtils.EPSILON) { return true; } else { return (prevLoss - currLoss) / prevLoss < epsilon; } } /** * Returns {@code true} if the array of metric values is converged. * * @param p an array of metric values. * @param isLargerBetter {@code true} if larger value is better. * @return {@code true} if the list of metric values is converged. 
*/ public static boolean isConverged(double[] p, boolean isLargerBetter) { final int bn = p.length; if (p.length <= 20) { return false; } double bestPerf = p[bn - 1]; double worstPerf = p[bn - 20]; for (int i = bn - 20; i < bn; i++) { if (MathUtils.isFirstBetter(p[i], bestPerf, isLargerBetter)) { bestPerf = p[i]; } if (!MathUtils.isFirstBetter(p[i], worstPerf, isLargerBetter)) { worstPerf = p[i]; } } double relMaxMin = Math.abs(worstPerf - bestPerf) / worstPerf; double relImprov; if (MathUtils.isFirstBetter(p[bn - 1], p[bn - 21], isLargerBetter)) { relImprov = Math.abs(p[bn - 21] - p[bn - 1]) / p[bn - 21]; } else { // Overfitting relImprov = Double.NaN; } return relMaxMin < 0.02 && (Double.isNaN(relImprov) || relImprov < 0.005); } /** * Returns {@code true} if the list of metric values is converged. * * @param list a list of metric values. * @param isLargerBetter {@code true} if larger value is better. * @return {@code true} if the list of metric values is converged. */ public static boolean isConverged(List list, boolean isLargerBetter) { if (list.size() <= 20) { return false; } final int bn = list.size(); double bestPerf = list.get(bn - 1); double worstPerf = list.get(bn - 20); for (int i = bn - 20; i < bn; i++) { if (MathUtils.isFirstBetter(list.get(i), bestPerf, isLargerBetter)) { bestPerf = list.get(i); } if (!MathUtils.isFirstBetter(list.get(i), worstPerf, isLargerBetter)) { worstPerf = list.get(i); } } double relMaxMin = Math.abs(worstPerf - bestPerf) / worstPerf; double relImprov; if (MathUtils.isFirstBetter(list.get(bn - 1), list.get(bn - 21), isLargerBetter)) { relImprov = Math.abs(list.get(bn - 21) - list.get(bn - 1)) / list.get(bn - 21); } else { // Overfitting relImprov = Double.NaN; } return relMaxMin < 0.02 && (Double.isNaN(relImprov) || relImprov < 0.005); } } ================================================ FILE: src/main/java/mltk/util/Permutation.java ================================================ package mltk.util; /** * Class for handling 
/**
 * A minimal first-in-first-out queue backed by a linked list.
 *
 * @author Yin Lou
 *
 * @param <T> the type of the queue.
 */
public class Queue<T> {

	protected LinkedList<T> list;

	/**
	 * Creates an empty queue.
	 */
	public Queue() {
		this.list = new LinkedList<T>();
	}

	/**
	 * Appends an item to the tail of the queue.
	 *
	 * @param item the item.
	 */
	public void enqueue(T item) {
		list.add(item);
	}

	/**
	 * Removes and returns the item at the head of the queue.
	 *
	 * @return the first element in the queue.
	 * @throws java.util.NoSuchElementException if the queue is empty.
	 */
	public T dequeue() {
		return list.removeFirst();
	}

	/**
	 * Tells whether the queue holds no items.
	 *
	 * @return {@code true} if the queue is empty.
	 */
	public boolean isEmpty() {
		return list.isEmpty();
	}

}
/**
 * Process-wide random number source: a lazily created singleton that forwards to a {@link java.util.Random}.
 *
 * @author Yin Lou
 *
 */
public class Random {

	protected static Random instance = null;
	protected java.util.Random rand;

	/**
	 * Creates the wrapper around a freshly seeded {@link java.util.Random}.
	 */
	protected Random() {
		this.rand = new java.util.Random();
	}

	/**
	 * Returns the shared random object, creating it on first use.
	 *
	 * @return the singleton random object.
	 */
	public static Random getInstance() {
		if (instance != null) {
			return instance;
		}
		instance = new Random();
		return instance;
	}

	/**
	 * Re-seeds the underlying generator.
	 *
	 * @param seed the random seed.
	 */
	public void setSeed(long seed) {
		rand.setSeed(seed);
	}

	/**
	 * Draws the next uniformly distributed int.
	 *
	 * @return a random integer.
	 */
	public int nextInt() {
		return rand.nextInt();
	}

	/**
	 * Draws the next uniformly distributed int between 0 (inclusive) and n (exclusive).
	 *
	 * @param n the range.
	 * @return a random integer in [0, n - 1].
	 */
	public int nextInt(int n) {
		return rand.nextInt(n);
	}

	/**
	 * Draws the next uniformly distributed double between 0.0 and 1.0.
	 *
	 * @return a random double value.
	 */
	public double nextDouble() {
		return rand.nextDouble();
	}

	/**
	 * Draws the next uniformly distributed float between 0.0 and 1.0.
	 *
	 * @return a random float value.
	 */
	public float nextFloat() {
		return rand.nextFloat();
	}

	/**
	 * Draws the next Gaussian distributed double with mean 0.0 and standard deviation 1.0.
	 *
	 * @return a random double value.
	 */
	public double nextGaussian() {
		return rand.nextGaussian();
	}

	/**
	 * Draws the next uniformly distributed long.
	 *
	 * @return a random long value.
	 */
	public long nextLong() {
		return rand.nextLong();
	}

	/**
	 * Draws the next uniformly distributed boolean.
	 *
	 * @return a random boolean value.
	 */
	public boolean nextBoolean() {
		return rand.nextBoolean();
	}

	/**
	 * Fills a user-supplied byte array with random bytes.
	 *
	 * @param bytes the byte array to fill with random bytes.
	 */
	public void nextBytes(byte[] bytes) {
		rand.nextBytes(bytes);
	}

	/**
	 * Exposes the wrapped Java random object.
	 *
	 * @return the backend Java random object.
	 */
	public java.util.Random getRandom() {
		return rand;
	}

}
/**
 * A minimal last-in-first-out stack backed by an array list.
 *
 * @author Yin Lou
 *
 * @param <T> the type of this stack.
 */
public class Stack<T> {

	protected List<T> list;

	/**
	 * Creates an empty stack.
	 */
	public Stack() {
		list = new ArrayList<T>();
	}

	/**
	 * Places an item on top of the stack.
	 *
	 * @param item the item.
	 */
	public void push(T item) {
		list.add(item);
	}

	/**
	 * Returns the item on top of the stack, leaving the stack unchanged.
	 *
	 * @return the top element in the stack.
	 * @throws EmptyStackException if the stack is empty.
	 */
	public T peek() {
		final int depth = list.size();
		if (depth == 0) {
			throw new EmptyStackException();
		}
		return list.get(depth - 1);
	}

	/**
	 * Removes the item on top of the stack and returns it.
	 *
	 * @return the top element in the stack.
	 * @throws EmptyStackException if the stack is empty.
	 */
	public T pop() {
		final T top = peek();
		list.remove(list.size() - 1);
		return top;
	}

	/**
	 * Tells whether the stack holds no items.
	 *
	 * @return {@code true} if the stack is empty.
	 */
	public boolean isEmpty() {
		return list.isEmpty();
	}

}
*/ public boolean isEmpty() { return list.size() == 0; } } ================================================ FILE: src/main/java/mltk/util/StatUtils.java ================================================ package mltk.util; /** * Class for utility functions for computing statistics. * * @author Yin Lou * */ public class StatUtils { /** * Returns the maximum element in an array. * * @param a the array. * @return the maximum element in an array. */ public static int max(int[] a) { int max = a[0]; for (int i = 1; i < a.length; i++) { if (a[i] > max) { max = a[i]; } } return max; } /** * Returns the maximum element in an array. * * @param a the array. * @return the maximum element in an array. */ public static double max(double[] a) { double max = a[0]; for (int i = 1; i < a.length; i++) { if (a[i] > max) { max = a[i]; } } return max; } /** * Returns the index of maximum element. * * @param a the array. * @return the index of maximum element. */ public static int indexOfMax(int[] a) { int max = a[0]; int idx = 0; for (int i = 1; i < a.length; i++) { if (a[i] > max) { max = a[i]; idx = i; } } return idx; } /** * Returns the index of maximum element. * * @param a the array. * @return the index of maximum element. */ public static int indexOfMax(double[] a) { double max = a[0]; int idx = 0; for (int i = 1; i < a.length; i++) { if (a[i] > max) { max = a[i]; idx = i; } } return idx; } /** * Returns the minimum element in an array. * * @param a the array. * @return the minimum element in an array. */ public static int min(int[] a) { int min = a[0]; for (int i = 1; i < a.length; i++) { if (a[i] < min) { min = a[i]; } } return min; } /** * Returns the minimum element in an array. * * @param a the array. * @return the minimum element in an array. */ public static double min(double[] a) { double min = a[0]; for (int i = 1; i < a.length; i++) { if (a[i] < min) { min = a[i]; } } return min; } /** * Returns the index of minimum element. * * @param a the array. 
* @return the index of minimum element. */ public static int indexOfMin(int[] a) { int min = a[0]; int idx = 0; for (int i = 1; i < a.length; i++) { if (a[i] < min) { min = a[i]; idx = i; } } return idx; } /** * Returns the index of minimum element. * * @param a the array. * @return the index of minimum element. */ public static int indexOfMin(double[] a) { double min = a[0]; int idx = 0; for (int i = 1; i < a.length; i++) { if (a[i] < min) { min = a[i]; idx = i; } } return idx; } /** * Returns the sum of elements in an array. * * @param a the array. * @return the sum of elements in an array. */ public static double sum(double[] a) { double sum = 0; for (double v : a) { sum += v; } return sum; } /** * Returns the sum of squares. * * @param a the array. * @return the sum of squares. */ public static double sumSq(double[] a) { return sumSq(a, 0, a.length); } /** * Returns the sum of squares within a specific range. * * @param a the array. * @param fromIndex the index of the first element (inclusive). * @param toIndex the index of the last element (exclusive). * @return the sum of squares. */ public static double sumSq(double[] a, int fromIndex, int toIndex) { double sq = 0.0; for (int i = fromIndex; i < toIndex; i++) { sq += a[i] * a[i]; } return sq; } /** * Returns the mean. * * @param a the array. * @return the mean. */ public static double mean(double[] a) { return mean(a, a.length); } /** * Returns the mean. * * @param a the array. * @param n the total number of elements. * @return the mean. */ public static double mean(double[] a, int n) { double avg = 0.0; for (double v : a) { avg += v; } return avg / n; } /** * Returns the variance. * * @param a the array. * @return the variance. */ public static double variance(double[] a) { return variance(a, a.length); } /** * Returns the variance. * * @param a the array. * @param n the total number of elements. * @return the variance. 
/**
 * Class for union-find sets.
 *
 * <p>
 * Each entry of {@code parent} is either the index of the element's parent (a non-negative value) or, for a
 * root, the negated number of elements in its set (a negative value).
 * </p>
 *
 * @author Yin Lou
 *
 */
public class UFSets {

    private int[] parent;

    /**
     * Constructor. Every element starts as the root of its own singleton set.
     *
     * @param size the size; keys {@code 0} to {@code size} (inclusive) are valid.
     */
    public UFSets(int size) {
        parent = new int[size + 1];
        for (int i = 0; i < parent.length; i++) {
            parent[i] = -1;
        }
    }

    /**
     * Unions two sets. The smaller set is attached below the root of the larger one. Both arguments must be
     * roots, i.e., values returned by {@link #find(int)}. Passing the same root twice is a no-op.
     *
     * @param root1 the root for the 1st set.
     * @param root2 the root for the 2nd set.
     */
    public void union(int root1, int root2) {
        if (root1 == root2) {
            // Already the same set; merging would double the recorded size and
            // skew every later weighted union involving this root.
            return;
        }
        int temp = parent[root1] + parent[root2];
        if (parent[root1] < parent[root2]) {
            // root1 is more negative, i.e., its set is larger: root2 goes below root1.
            parent[root2] = root1;
            parent[root1] = temp;
        } else {
            parent[root1] = root2;
            parent[root2] = temp;
        }
    }

    /**
     * Returns the root of the set that contains the search key. Every element visited on the way to the root
     * is re-linked directly to the root.
     *
     * @param i the search key.
     * @return the root of the set that contains the search key.
     */
    public int find(int i) {
        int root = i;
        while (parent[root] >= 0) {
            root = parent[root];
        }
        // Second pass: point every node on the path straight at the root.
        while (i != root) {
            int next = parent[i];
            parent[i] = root;
            i = next;
        }
        return root;
    }

}
*/ public static double l1norm(double[] a) { double norm = 0; for (double v : a) { norm += Math.abs(v); } return norm; } /** * Returns the L1 norm of a vector. * * @param v the vector. * @return the L1 norm of a vector. */ public static double l1norm(Vector v) { return l1norm(v.getValues()); } /** * Returns the dot product of two vectors. * * @param a the 1st vector. * @param b the 2nd vector. * @return the dot product of two vectors. */ public static double dotProduct(double[] a, double[] b) { double s = 0; for (int i = 0; i < a.length; i++) { s += a[i] * b[i]; } return s; } /** * Returns the dot product of two vectors. * * @param a the 1st vector. * @param b the 2nd vector. * @return the dot product of two vectors. */ public static double dotProduct(DenseVector a, DenseVector b) { return dotProduct(a.getValues(), b.getValues()); } /** * Returns the dot product of two vectors. * * @param a the 1st vector. * @param b the 2nd vector. * @return the dot product of two vectors. */ public static double dotProduct(SparseVector a, DenseVector b) { int[] indices1 = a.getIndices(); double[] values1 = a.getValues(); double[] values2 = b.getValues(); double s = 0; for (int i = 0; i < indices1.length; i++) { s += values1[i] * values2[indices1[i]]; } return s; } /** * Returns the dot product of two vectors. * * @param a the 1st vector. * @param b the 2nd vector. * @return the dot product of two vectors. */ public static double dotProduct(DenseVector a, SparseVector b) { return dotProduct(b, a); } /** * Returns the dot product of two vectors. * * @param a the 1st vector. * @param b the 2nd vector. * @return the dot product of two vectors. 
/**
 * Class for &lt;double, double&gt; pairs.
 *
 * <p>
 * Equality and hashing are based on {@link Double#doubleToLongBits(double)}, so {@code NaN} equals
 * {@code NaN} and {@code 0.0} differs from {@code -0.0}.
 * </p>
 *
 * @author Yin Lou
 *
 */
public class DoublePair {

    public double v1;
    public double v2;

    /**
     * Constructor.
     *
     * @param v1 the 1st double.
     * @param v2 the 2nd double.
     */
    public DoublePair(double v1, double v2) {
        this.v1 = v1;
        this.v2 = v2;
    }

    @Override
    public int hashCode() {
        int hash = 31 + foldBits(v1);
        return 31 * hash + foldBits(v2);
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        DoublePair that = (DoublePair) obj;
        return Double.doubleToLongBits(v1) == Double.doubleToLongBits(that.v1)
                && Double.doubleToLongBits(v2) == Double.doubleToLongBits(that.v2);
    }

    /** Folds the 64-bit representation of a double into 32 bits. */
    private static int foldBits(double value) {
        long bits = Double.doubleToLongBits(value);
        return (int) (bits ^ (bits >>> 32));
    }

}
*/ public IntDoublePair(int v1, double v2) { this.v1 = v1; this.v2 = v2; } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + v1; long temp; temp = Double.doubleToLongBits(v2); result = prime * result + (int) (temp ^ (temp >>> 32)); return result; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; IntDoublePair other = (IntDoublePair) obj; if (v1 != other.v1) return false; if (Double.doubleToLongBits(v2) != Double.doubleToLongBits(other.v2)) return false; return true; } } ================================================ FILE: src/main/java/mltk/util/tuple/IntDoublePairComparator.java ================================================ package mltk.util.tuple; import java.util.Comparator; /** * Class for comparing <int, double> pairs. By default int is used as key, and in ascending order. * * @author Yin Lou * */ public class IntDoublePairComparator implements Comparator { protected boolean ascending; protected boolean firstIsKey; public IntDoublePairComparator() { this(true, true); } public IntDoublePairComparator(boolean firstIsKey) { this(firstIsKey, true); } public IntDoublePairComparator(boolean firstIsKey, boolean ascending) { this.firstIsKey = firstIsKey; this.ascending = ascending; } @Override public int compare(IntDoublePair o1, IntDoublePair o2) { int cmp = 0; if (firstIsKey) { cmp = Integer.compare(o1.v1, o2.v1); } else { cmp = Double.compare(o1.v2, o2.v2); } if (!ascending) { cmp = -cmp; } return cmp; } } ================================================ FILE: src/main/java/mltk/util/tuple/IntPair.java ================================================ package mltk.util.tuple; /** * Class for <int, int> pairs. * * @author Yin Lou * */ public class IntPair { public int v1; public int v2; /** * Constructor. * * @param v1 the 1st int. * @param v2 the 2nd int. 
/**
 * Class for &lt;int, int, int&gt; triples.
 *
 * @author Yin Lou
 *
 */
public class IntTriple {

    public int v1;
    public int v2;
    public int v3;

    /**
     * Constructor.
     *
     * @param v1 the 1st int.
     * @param v2 the 2nd int.
     * @param v3 the 3rd int.
     */
    public IntTriple(int v1, int v2, int v3) {
        this.v1 = v1;
        this.v2 = v2;
        this.v3 = v3;
    }

    @Override
    public int hashCode() {
        int hash = 31 + v1;
        hash = 31 * hash + v2;
        hash = 31 * hash + v3;
        return hash;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        IntTriple that = (IntTriple) obj;
        return v1 == that.v1 && v2 == that.v2 && v3 == that.v3;
    }

}
*/ public LongDoublePair(long v1, double v2) { this.v1 = v1; this.v2 = v2; } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + (int) (v1 ^ (v1 >>> 32)); long temp; temp = Double.doubleToLongBits(v2); result = prime * result + (int) (temp ^ (temp >>> 32)); return result; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; LongDoublePair other = (LongDoublePair) obj; if (v1 != other.v1) return false; if (Double.doubleToLongBits(v2) != Double.doubleToLongBits(other.v2)) return false; return true; } @Override public String toString() { return "(" + v1 + ", " + v2 + ")"; } } ================================================ FILE: src/main/java/mltk/util/tuple/LongDoublePairComparator.java ================================================ package mltk.util.tuple; import java.util.Comparator; /** * Class for comparing <long, double> pairs. By default long is used as key, and in ascending order. * * @author Yin Lou * */ public class LongDoublePairComparator implements Comparator { protected boolean ascending; protected boolean firstIsKey; public LongDoublePairComparator() { this(true, true); } public LongDoublePairComparator(boolean firstIsKey) { this(firstIsKey, true); } public LongDoublePairComparator(boolean firstIsKey, boolean ascending) { this.firstIsKey = firstIsKey; this.ascending = ascending; } @Override public int compare(LongDoublePair o1, LongDoublePair o2) { int cmp = 0; if (firstIsKey) { cmp = Long.compare(o1.v1, o2.v1); } else { cmp = Double.compare(o1.v2, o2.v2); } if (!ascending) { cmp = -cmp; } return cmp; } } ================================================ FILE: src/main/java/mltk/util/tuple/LongPair.java ================================================ package mltk.util.tuple; /** * Class for <long, long> pairs. 
/**
 * Class for generic pairs.
 *
 * <p>
 * Two pairs are equal when they have the same runtime class and their corresponding elements are either
 * both {@code null}, or of the same runtime class and equal according to {@code equals}.
 * </p>
 *
 * @author Yin Lou
 *
 * @param <T1> the type of the 1st element.
 * @param <T2> the type of the 2nd element.
 */
public class Pair<T1, T2> {

    public T1 v1;
    public T2 v2;

    /**
     * Constructor.
     *
     * @param v1 the 1st element.
     * @param v2 the 2nd element.
     */
    public Pair(T1 v1, T2 v2) {
        this.v1 = v1;
        this.v2 = v2;
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + ((v1 == null) ? 0 : v1.hashCode());
        result = prime * result + ((v2 == null) ? 0 : v2.hashCode());
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        Pair<?, ?> other = (Pair<?, ?>) obj;
        return sameElement(v1, other.v1) && sameElement(v2, other.v2);
    }

    /**
     * Compares two elements null-safely: both {@code null}, or same runtime class and equal.
     */
    private static boolean sameElement(Object a, Object b) {
        if (a == null) {
            return b == null;
        }
        if (b == null) {
            // Without this check, b.getClass() below would throw a NullPointerException.
            return false;
        }
        return a.getClass() == b.getClass() && a.equals(b);
    }

}
0 : v3.hashCode()); return result; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; Triple other = (Triple) obj; if (v1 == null) { if (other.v1 != null) return false; } else if (!v1.equals(other.v1)) return false; if (v2 == null) { if (other.v2 != null) return false; } else if (!v2.equals(other.v2)) return false; if (v3 == null) { if (other.v3 != null) return false; } else if (!v3.equals(other.v3)) return false; return true; } } ================================================ FILE: src/main/java/mltk/util/tuple/package-info.java ================================================ /** * Contains utility classes for tuples. */ package mltk.util.tuple; ================================================ FILE: src/test/java/mltk/core/BinsTest.java ================================================ package mltk.core; import org.junit.Assert; import org.junit.Test; import mltk.util.MathUtils; public class BinsTest { @Test public void testBins() { Bins bins = new Bins(new double[] {1, 5, 6}, new double[] {0.5, 2.5, 5.5}); Assert.assertEquals(0, bins.getIndex(-1)); Assert.assertEquals(0, bins.getIndex(0.3)); Assert.assertEquals(0, bins.getIndex(1)); Assert.assertEquals(1, bins.getIndex(1.1)); Assert.assertEquals(1, bins.getIndex(5)); Assert.assertEquals(2, bins.getIndex(5.5)); Assert.assertEquals(2, bins.getIndex(6.5)); Assert.assertEquals(0.5, bins.getValue(0), MathUtils.EPSILON); Assert.assertEquals(2.5, bins.getValue(1), MathUtils.EPSILON); Assert.assertEquals(5.5, bins.getValue(2), MathUtils.EPSILON); } } ================================================ FILE: src/test/java/mltk/core/InstancesTestHelper.java ================================================ package mltk.core; import java.util.ArrayList; import java.util.List; public class InstancesTestHelper { private static InstancesTestHelper instance = null; private Instances denseClaDataset; private Instances 
denseRegDataset; private Instances denseClaDatasetWMissing; private Instances denseRegDatasetWMissing; public static InstancesTestHelper getInstance() { if (instance == null) { instance = new InstancesTestHelper(); } return instance; } public Instances getDenseClassificationDataset() { return denseClaDataset; } public Instances getDenseRegressionDataset() { return denseRegDataset; } public Instances getDenseClassificationDatasetWMissing() { return denseClaDatasetWMissing; } public Instances getDenseRegressionDatasetWMissing() { return denseRegDatasetWMissing; } private InstancesTestHelper() { List attributes = new ArrayList<>(); NumericalAttribute f1 = new NumericalAttribute("f1", 0); NominalAttribute f2 = new NominalAttribute("f2", new String[] {"a", "b", "c"}, 1); BinnedAttribute f3 = new BinnedAttribute("f3", 256, 2); Bins bins = new Bins(new double[] {1, 5, 6}, new double[] {0.5, 2.5, 3}); BinnedAttribute f4 = new BinnedAttribute("f4", bins, 3); attributes.add(f1); attributes.add(f2); attributes.add(f3); attributes.add(f4); Attribute claTarget = new NominalAttribute("target", new String[] {"0", "1"}); Attribute regTarget = new NumericalAttribute("target"); denseClaDataset = new Instances(attributes, claTarget); for (int i = 0; i < 1000; i++) { double[] v = new double[4]; v[0] = i * 0.1; v[1] = i % f2.getCardinality(); v[2] = (i + 1000) % f3.getNumBins(); v[3] = i % f3.getNumBins(); double target = (i % 10) < 8 ? 
0 : 1; Instance instance = new Instance(v, target); denseClaDataset.add(instance); } denseRegDataset = denseClaDataset.copy(); denseRegDataset.setTargetAttribute(regTarget); denseClaDatasetWMissing = denseClaDataset.copy(); for (int i = 0; i < 10; i++) { denseClaDatasetWMissing.get(i).setValue(0, Double.NaN); } denseRegDatasetWMissing = denseClaDatasetWMissing.copy(); denseRegDatasetWMissing.setTargetAttribute(regTarget); } } ================================================ FILE: src/test/java/mltk/core/io/AttributesReaderTest.java ================================================ package mltk.core.io; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStreamReader; import java.io.PrintWriter; import java.util.List; import org.junit.Assert; import org.junit.Test; import mltk.core.Attribute; import mltk.core.Attribute.Type; import mltk.core.BinnedAttribute; import mltk.core.NominalAttribute; import mltk.util.tuple.Pair; public class AttributesReaderTest { @Test public void testIO() { ByteArrayOutputStream boas = new ByteArrayOutputStream(); PrintWriter out = new PrintWriter(boas); out.println("f1: cont"); out.println("f2: {a, b, c}"); out.println("f3: binned (256)"); out.println("f4: binned (3;[1, 5, 6];[0.5, 2.5, 3])"); out.println("label: cont (target)"); out.flush(); out.close(); ByteArrayInputStream bais = new ByteArrayInputStream(boas.toByteArray()); BufferedReader br = new BufferedReader(new InputStreamReader(bais)); Pair, Attribute> pair = null; try { pair = AttributesReader.read(br); } catch (IOException e) { Assert.fail("Should not see exception: " + e.getMessage()); } List attributes = pair.v1; Attribute targetAtt = pair.v2; Assert.assertEquals("label", targetAtt.getName()); Assert.assertEquals(4, attributes.size()); for (int i = 0; i < attributes.size(); i++) { Assert.assertEquals(i, attributes.get(i).getIndex()); } Assert.assertEquals(Type.NUMERIC, 
attributes.get(0).getType()); Assert.assertEquals(Type.NOMINAL, attributes.get(1).getType()); Assert.assertEquals(Type.BINNED, attributes.get(2).getType()); Assert.assertEquals(Type.BINNED, attributes.get(3).getType()); Assert.assertArrayEquals(new String[] {"a", "b", "c"}, ((NominalAttribute) attributes.get(1)).getStates()); Assert.assertEquals(256, ((BinnedAttribute) attributes.get(2)).getNumBins()); Assert.assertEquals(3, ((BinnedAttribute) attributes.get(3)).getNumBins()); } } ================================================ FILE: src/test/java/mltk/core/io/InstancesReaderTest.java ================================================ package mltk.core.io; import org.junit.Assert; import org.junit.Test; import mltk.core.Instance; public class InstancesReaderTest { @Test public void testDenseFormat() { String[] data = {"0.0", "1.5", "?", "3"}; Instance instance = InstancesReader.parseDenseInstance(data, 3); Assert.assertTrue(instance.isMissing(2)); } } ================================================ FILE: src/test/java/mltk/core/processor/DiscretizerTest.java ================================================ package mltk.core.processor; import org.junit.Assert; import org.junit.Test; import mltk.core.BinnedAttribute; import mltk.core.Instances; import mltk.core.InstancesTestHelper; import mltk.util.MathUtils; public class DiscretizerTest { @Test public void testMissingValue() { Instances instances = InstancesTestHelper.getInstance().getDenseClassificationDatasetWMissing().copy(); Discretizer.discretize(instances, 0, 10); Assert.assertTrue(instances.getAttributes().get(0).getClass() == BinnedAttribute.class); for (int i = 0; i < 10; i++) { Assert.assertTrue(instances.get(i).isMissing(0)); } for (int i = 10; i < 20; i++) { Assert.assertFalse(instances.get(i).isMissing(0)); Assert.assertTrue(MathUtils.isInteger(instances.get(i).getValue(0))); } } } ================================================ FILE: src/test/java/mltk/core/processor/InstancesSplitterTest.java 
================================================ package mltk.core.processor; import org.junit.Assert; import org.junit.Test; import mltk.core.Instance; import mltk.core.Instances; import mltk.core.InstancesTestHelper; import mltk.util.Random; public class InstancesSplitterTest { @Test public void testSamplingTrainValid() { Random.getInstance().setSeed(5); Instances instances = InstancesTestHelper.getInstance() .getDenseRegressionDataset(); Instances[] datasets = InstancesSplitter.split(instances, 0.8); Instances train = datasets[0]; Instances valid = datasets[1]; Assert.assertEquals((int) (instances.size() * 0.8), train.size()); Assert.assertEquals((int) (instances.size() * 0.2), valid.size()); } @Test public void testSamplingTrainValidTest() { Random.getInstance().setSeed(5); Instances instances = InstancesTestHelper.getInstance() .getDenseRegressionDataset(); Instances[] datasets = InstancesSplitter.split(instances, 0.7, 0.1, 0.1); Instances train = datasets[0]; Instances valid = datasets[1]; Instances test = datasets[2]; Assert.assertEquals((int) (instances.size() * 0.7), train.size()); Assert.assertEquals((int) (instances.size() * 0.1), valid.size()); Assert.assertEquals((int) (instances.size() * 0.1), test.size()); } @Test public void testStratifiedSampling() { Random.getInstance().setSeed(5); Instances instances = InstancesTestHelper.getInstance() .getDenseClassificationDataset(); Instances[] datasets = InstancesSplitter.split(instances, "target", 0.8, 0.2); Instances train = datasets[0]; Instances valid = datasets[1]; int numPosTrain = 0; for (Instance instance : train) { if (instance.getTarget() == 1) { numPosTrain++; } } int numPosValid = 0; for (Instance instance : valid) { if (instance.getTarget() == 1) { numPosValid++; } } Assert.assertTrue(numPosTrain < train.size() / 3); Assert.assertTrue(numPosValid < valid.size() / 3); Assert.assertEquals((int) (instances.size() * 0.8), train.size()); Assert.assertEquals((int) (instances.size() * 0.2), 
valid.size()); } } ================================================ FILE: src/test/java/mltk/predictor/evaluation/AUCTest.java ================================================ package mltk.predictor.evaluation; import mltk.util.MathUtils; import org.junit.Assert; import org.junit.Test; public class AUCTest { @Test public void test1() { double[] preds = {0.8, 0.1, 0.05, 0.9}; double[] targets = {1, 0, 0, 1}; AUC metric = new AUC(); Assert.assertEquals(1, metric.eval(preds, targets), MathUtils.EPSILON); } @Test public void test2() { double[] preds = {0.5, 0.5, 0.5, 0.5}; double[] targets = {1, 0, 0, 1}; AUC metric = new AUC(); Assert.assertEquals(0.5, metric.eval(preds, targets), MathUtils.EPSILON); } } ================================================ FILE: src/test/java/mltk/predictor/evaluation/ConvergenceTesterTest.java ================================================ package mltk.predictor.evaluation; import org.junit.Assert; import org.junit.Test; public class ConvergenceTesterTest { @Test public void test1() { ConvergenceTester ct = new ConvergenceTester(10, 0, 0.8); ct.setMetric(new AUC()); ct.add(0.55); ct.add(0.60); ct.add(0.70); ct.add(0.75); ct.add(0.80); ct.add(0.85); ct.add(0.90); // Peak ct.add(0.85); Assert.assertFalse(ct.isConverged()); ct.add(0.82); ct.add(0.81); Assert.assertTrue(ct.isConverged()); } @Test public void test2() { ConvergenceTester ct = new ConvergenceTester(10, 2, 1.0); ct.setMetric(new AUC()); ct.add(0.55); ct.add(0.60); ct.add(0.70); ct.add(0.75); ct.add(0.80); ct.add(0.85); ct.add(0.90); // Peak ct.add(0.85); Assert.assertFalse(ct.isConverged()); ct.add(0.82); ct.add(0.81); Assert.assertTrue(ct.isConverged()); } @Test public void test3() { ConvergenceTester ct = new ConvergenceTester(10, 0, 0.8); ct.setMetric(new RMSE()); ct.add(5.00); ct.add(4.50); ct.add(4.00); ct.add(3.50); ct.add(3.00); ct.add(2.50); ct.add(2.00); // Bottom ct.add(2.20); Assert.assertFalse(ct.isConverged()); ct.add(2.10); ct.add(2.05); 
Assert.assertTrue(ct.isConverged()); } @Test public void test4() { ConvergenceTester ct = new ConvergenceTester(10, 2, 1.0); ct.setMetric(new RMSE()); ct.add(5.00); ct.add(4.50); ct.add(4.00); ct.add(3.50); ct.add(3.00); ct.add(2.50); ct.add(2.00); // Bottom ct.add(2.20); Assert.assertFalse(ct.isConverged()); ct.add(2.10); ct.add(2.05); Assert.assertTrue(ct.isConverged()); } @Test public void testParse() { ConvergenceTester ct = null; // Empty ct = ConvergenceTester.parse(""); Assert.assertEquals(ct.minNumPoints, -1); Assert.assertEquals(ct.n, 0); Assert.assertEquals(ct.c, 1.0, 1e-6); // One parameter ct = ConvergenceTester.parse("10"); Assert.assertEquals(ct.minNumPoints, 10); Assert.assertEquals(ct.n, 0); Assert.assertEquals(ct.c, 1.0, 1e-6); // Two parameters ct = ConvergenceTester.parse("10:5"); Assert.assertEquals(ct.minNumPoints, 10); Assert.assertEquals(ct.n, 5); Assert.assertEquals(ct.c, 1.0, 1e-6); // Three parameters ct = ConvergenceTester.parse("10:5:0.8"); Assert.assertEquals(ct.minNumPoints, 10); Assert.assertEquals(ct.n, 5); Assert.assertEquals(ct.c, 0.8, 1e-6); } } ================================================ FILE: src/test/java/mltk/predictor/evaluation/ErrorTest.java ================================================ package mltk.predictor.evaluation; import mltk.util.MathUtils; import org.junit.Assert; import org.junit.Test; public class ErrorTest { @Test public void testLabel() { double[] preds = {1, 0, 1, 0}; double[] targets = {1, 0, 0, 1}; Error metric = new Error(); Assert.assertEquals(0.5, metric.eval(preds, targets), MathUtils.EPSILON); } @Test public void testProbability() { double[] preds = {2, -1.5, 0.3, 5}; double[] targets = {1, 0, 0, 1}; Error metric = new Error(); Assert.assertEquals(0.25, metric.eval(preds, targets), MathUtils.EPSILON); } } ================================================ FILE: src/test/java/mltk/predictor/evaluation/LogLossTest.java ================================================ package 
mltk.predictor.evaluation;

import mltk.util.MathUtils;

import org.junit.Assert;
import org.junit.Test;

/** Unit tests for {@link LogLoss}. */
public class LogLossTest {

	/** Evaluates a metric built with {@code new LogLoss(false)}. */
	@Test
	public void testProb() {
		double[] predictions = {0.8, 0.1, 0.05, 0.9};
		double[] labels = {1, 0, 0, 1};
		LogLoss logLoss = new LogLoss(false);
		Assert.assertEquals(0.485157877, logLoss.eval(predictions, labels), MathUtils.EPSILON);
	}

	/** Evaluates a metric built with {@code new LogLoss(true)}. */
	@Test
	public void testRawScore() {
		double[] predictions = {5, -5, -3, 3};
		double[] labels = {1, 0, 0, 1};
		LogLoss logLoss = new LogLoss(true);
		Assert.assertEquals(0.1106054, logLoss.eval(predictions, labels), MathUtils.EPSILON);
	}

}
================================================
FILE: src/test/java/mltk/predictor/evaluation/LogisticLossTest.java
================================================
package mltk.predictor.evaluation;

import mltk.util.MathUtils;

import org.junit.Assert;
import org.junit.Test;

/** Unit test for {@link LogisticLoss}. */
public class LogisticLossTest {

	/** Evaluates the metric on four predictions with 0/1 targets. */
	@Test
	public void test() {
		double[] predictions = {5, -5, -3, 3};
		double[] labels = {1, 0, 0, 1};
		LogisticLoss loss = new LogisticLoss();
		Assert.assertEquals(0.02765135, loss.eval(predictions, labels), MathUtils.EPSILON);
	}

}
================================================
FILE: src/test/java/mltk/predictor/evaluation/MAETest.java
================================================
package mltk.predictor.evaluation;

import mltk.util.MathUtils;

import org.junit.Assert;
import org.junit.Test;

/** Unit test for {@link MAE}. */
public class MAETest {

	/** Absolute differences 0.9, 1.8, 2.7 and 3.6 average to 2.25. */
	@Test
	public void test() {
		double[] predictions = {1, 2, 3, 4};
		double[] labels = {0.1, 0.2, 0.3, 0.4};
		MAE mae = new MAE();
		Assert.assertEquals(2.25, mae.eval(predictions, labels), MathUtils.EPSILON);
	}

}
================================================
FILE: src/test/java/mltk/predictor/evaluation/MetricFactoryTest.java
================================================
package mltk.predictor.evaluation;

import org.junit.Assert;
import org.junit.Test;

/** Unit test for {@link MetricFactory}. */
public class MetricFactoryTest {

	/**
	 * Checks that each metric name resolves to the expected class and that
	 * {@code "LogLoss:True"} yields a {@link LogLoss} whose
	 * {@code isRawScore()} is true.
	 */
	@Test
	public void test() {
		Assert.assertEquals(AUC.class, MetricFactory.getMetric("AUC").getClass());
		Assert.assertEquals(Error.class, MetricFactory.getMetric("Error").getClass());
		Assert.assertEquals(LogisticLoss.class, MetricFactory.getMetric("LogisticLoss").getClass());
		Assert.assertEquals(LogLoss.class, MetricFactory.getMetric("LogLoss").getClass());
		Assert.assertEquals(LogLoss.class, MetricFactory.getMetric("LogLoss:True").getClass());
		Assert.assertTrue(((LogLoss) MetricFactory.getMetric("LogLoss:True")).isRawScore());
		Assert.assertEquals(MAE.class, MetricFactory.getMetric("MAE").getClass());
		Assert.assertEquals(RMSE.class, MetricFactory.getMetric("RMSE").getClass());
	}

}
================================================
FILE: src/test/java/mltk/predictor/evaluation/RMSETest.java
================================================
package mltk.predictor.evaluation;

import mltk.util.MathUtils;

import org.junit.Assert;
import org.junit.Test;

/** Unit test for {@link RMSE}. */
public class RMSETest {

	/** Squared differences 0.01, 0.01, 0.04 and 0 give sqrt(0.06 / 4). */
	@Test
	public void test() {
		double[] predictions = {1, 2, 3, 4};
		double[] labels = {1.1, 1.9, 3.2, 4};
		RMSE rmse = new RMSE();
		Assert.assertEquals(0.122474487, rmse.eval(predictions, labels), MathUtils.EPSILON);
	}

}
================================================
FILE: src/test/java/mltk/predictor/glm/GLMTest.java
================================================
package mltk.predictor.glm;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;

import mltk.predictor.io.PredictorReader;
import mltk.util.MathUtils;

import org.junit.Assert;
import org.junit.Test;

/** Unit test for {@link GLM} serialization. */
public class GLMTest {

	/**
	 * Writes a two-row GLM to an in-memory buffer, reads it back through
	 * {@link PredictorReader} and compares intercepts and coefficients.
	 */
	@Test
	public void testIO() {
		double[] intercept = {1.0, -1.0};
		double[][] w = {
			{1, 2, 3},
			{-1, -2, -3}
		};
		GLM glm = new GLM(intercept, w);

		ByteArrayOutputStream buffer = new ByteArrayOutputStream();
		PrintWriter writer = new PrintWriter(buffer);
		try {
			glm.write(writer);
		} catch (Exception e) {
			Assert.fail("Should not see exception: " + e.getMessage());
		}
		writer.flush();
		writer.close();

		ByteArrayInputStream bytes = new ByteArrayInputStream(buffer.toByteArray());
		BufferedReader reader = new BufferedReader(new InputStreamReader(bytes));
		try {
			GLM parsedGLM = PredictorReader.read(reader, GLM.class);
			Assert.assertEquals(intercept.length, parsedGLM.intercept().length);
			Assert.assertEquals(w.length, parsedGLM.coefficients().length);
			Assert.assertArrayEquals(intercept, parsedGLM.intercept, MathUtils.EPSILON);
			// The loop indexes the rows of w, so it is bounded by w.length.
			for (int i = 0; i < w.length; i++) {
				Assert.assertArrayEquals(w[i], parsedGLM.coefficients(i), MathUtils.EPSILON);
			}
		} catch (Exception e) {
			Assert.fail("Should not see exception: " + e.getMessage());
		}
	}

}
================================================
FILE: src/test/java/mltk/predictor/tree/DecisionTableLearnerTest.java
================================================
package mltk.predictor.tree;

import java.util.List;

import org.junit.Assert;
import org.junit.Test;

import mltk.core.Attribute;
import mltk.core.Instances;
import mltk.core.InstancesTestHelper;

/** Unit tests for {@link DecisionTableLearner}. */
public class DecisionTableLearnerTest {

	/**
	 * Builds a depth-2 table on the dense regression dataset and expects the
	 * first selected attribute index to be 0.
	 */
	@Test
	public void testDecisionTableLearner1() {
		DecisionTableLearner learner = new DecisionTableLearner();
		learner.setConstructionMode(DecisionTableLearner.Mode.ONE_PASS_GREEDY);
		learner.setMaxDepth(2);

		Instances instances = InstancesTestHelper.getInstance().getDenseRegressionDataset();
		DecisionTable table = learner.build(instances);

		int[] attributeIndices = table.getAttributeIndices();
		Assert.assertEquals(2, attributeIndices.length);
		Assert.assertEquals(0, attributeIndices[0]);
	}

	/**
	 * Same as above, but on a copy of the dataset restricted to attribute 1;
	 * the first selected attribute index must then be 1.
	 */
	@Test
	public void testDecisionTableLearner2() {
		DecisionTableLearner learner = new DecisionTableLearner();
		learner.setConstructionMode(DecisionTableLearner.Mode.ONE_PASS_GREEDY);
		learner.setMaxDepth(2);

		Instances instances = InstancesTestHelper.getInstance()
				.getDenseRegressionDataset().copy();
		// Apply feature selection
		List<Attribute> attributes = instances.getAttributes(1);
		instances.setAttributes(attributes);
		DecisionTable table = learner.build(instances);

		int[] attributeIndices = table.getAttributeIndices();
		Assert.assertEquals(2, attributeIndices.length);
		Assert.assertEquals(1, attributeIndices[0]);
	}

}
================================================
FILE: src/test/java/mltk/predictor/tree/DecisionTableTest.java
================================================
package mltk.predictor.tree;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;

import org.junit.Assert;
import org.junit.Test;

import mltk.core.Instance;
import mltk.core.InstancesTestHelper;
import mltk.predictor.io.PredictorReader;
import mltk.util.MathUtils;

/** Unit tests for {@link DecisionTable}. */
public class DecisionTableTest {

	/**
	 * Round-trips table 1 through an in-memory buffer and compares attribute
	 * indices, splits and the prediction for the first dataset instance.
	 */
	@Test
	public void testIO() {
		DecisionTable expected = DecisionTableTestHelper.getInstance().getTable1();
		Instance instance = InstancesTestHelper.getInstance().getDenseRegressionDataset().get(0);

		try {
			ByteArrayOutputStream buffer = new ByteArrayOutputStream();
			PrintWriter writer = new PrintWriter(buffer);
			expected.write(writer);
			writer.close();

			ByteArrayInputStream bytes = new ByteArrayInputStream(buffer.toByteArray());
			BufferedReader reader = new BufferedReader(new InputStreamReader(bytes));
			DecisionTable actual = PredictorReader.read(reader, DecisionTable.class);

			Assert.assertArrayEquals(expected.getAttributeIndices(), actual.getAttributeIndices());
			Assert.assertArrayEquals(expected.getSplits(), actual.getSplits(), MathUtils.EPSILON);
			Assert.assertEquals(expected.regress(instance), actual.regress(instance), MathUtils.EPSILON);
		} catch (Exception e) {
			Assert.fail("Should not see exception: " + e.getMessage());
		}
	}

	/** Table 1 must predict 0.7 for the first instance of the dense regression dataset. */
	@Test
	public void testRegress() {
		DecisionTable table = DecisionTableTestHelper.getInstance().getTable1();
		Instance instance = InstancesTestHelper.getInstance().getDenseRegressionDataset().get(0);
		Assert.assertEquals(0.7, table.regress(instance), MathUtils.EPSILON);
	}

}
================================================
FILE: src/test/java/mltk/predictor/tree/DecisionTableTestHelper.java
================================================
package mltk.predictor.tree;

/**
 * Lazily created singleton that supplies two hand-built {@link DecisionTable}
 * fixtures to the tests.
 */
public class DecisionTableTestHelper {

	private static DecisionTableTestHelper instance = null;

	private DecisionTable dt1;
	private DecisionTable dt2;

	/**
	 * Returns the shared helper, creating it on first use.
	 *
	 * @return the singleton helper.
	 */
	public static DecisionTableTestHelper getInstance() {
		if (instance == null) {
			instance = new DecisionTableTestHelper();
		}
		return instance;
	}

	/**
	 * Returns the first fixture table.
	 *
	 * @return the table over attributes 0, 1 and 2.
	 */
	public DecisionTable getTable1() {
		return dt1;
	}

	/**
	 * Returns the second fixture table.
	 *
	 * @return the table over attributes 3, 2 and 0.
	 */
	public DecisionTable getTable2() {
		return dt2;
	}

	private DecisionTableTestHelper() {
		buildDecisionTable1();
		buildDecisionTable2();
	}

	/** Builds a three-attribute table with eight prediction cells (0.1 .. 0.8). */
	private void buildDecisionTable1() {
		int[] attIndices = {0, 1, 2};
		double[] splits = {50, 1.5, 23.5};
		long[] predIndices = {0, 1, 2, 3, 4, 5, 6, 7};
		double[] predValues = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8};
		dt1 = new DecisionTable(attIndices, splits, predIndices, predValues);
	}

	/**
	 * Builds a three-attribute table with eight prediction cells of alternating
	 * sign. There is exactly one value per prediction index.
	 */
	private void buildDecisionTable2() {
		int[] attIndices = {3, 2, 0};
		double[] splits = {1, 56.5, 20};
		long[] predIndices = {0, 1, 2, 3, 4, 5, 6, 7};
		double[] predValues = {1.0, -0.9, 0.8, -0.7, 0.6, -0.5, 0.4, -0.3};
		dt2 = new DecisionTable(attIndices, splits, predIndices, predValues);
	}

}
================================================
FILE: src/test/java/mltk/predictor/tree/RegressionTreeLearnerTest.java
================================================
package mltk.predictor.tree;

import java.util.List;

import org.junit.Assert;
import org.junit.Test;

import mltk.core.Attribute;
import mltk.core.Instances;
import mltk.core.InstancesTestHelper;

/** Unit tests for {@link RegressionTreeLearner}. */
public class RegressionTreeLearnerTest {

	/**
	 * Builds a depth-limited tree on the dense regression dataset and expects
	 * the root to split on attribute 0 and to have both children.
	 */
	@Test
	public void testRegressionTreeLearner1() {
		RegressionTreeLearner learner = new RegressionTreeLearner();
		learner.setConstructionMode(RegressionTreeLearner.Mode.DEPTH_LIMITED);
		learner.setMaxDepth(2);

		Instances instances = InstancesTestHelper.getInstance().getDenseRegressionDataset();
		RegressionTree tree = learner.build(instances);

		TreeInteriorNode root = (TreeInteriorNode) tree.getRoot();
		Assert.assertEquals(0, root.attIndex);
		Assert.assertNotNull(root.getLeftChild());
		Assert.assertNotNull(root.getRightChild());
	}

	/**
	 * Same as above, but on a copy of the dataset restricted to attribute 1;
	 * the root must then split on attribute 1.
	 */
	@Test
	public void testRegressionTreeLearner2() {
		RegressionTreeLearner learner = new RegressionTreeLearner();
		learner.setConstructionMode(RegressionTreeLearner.Mode.DEPTH_LIMITED);
		learner.setMaxDepth(2);

		Instances instances = InstancesTestHelper.getInstance()
				.getDenseRegressionDataset().copy();
		// Apply feature selection
		List<Attribute> attributes = instances.getAttributes(1);
		instances.setAttributes(attributes);
		RegressionTree tree = learner.build(instances);

		TreeInteriorNode root = (TreeInteriorNode) tree.getRoot();
		Assert.assertEquals(1, root.attIndex);
	}

}
================================================
FILE: src/test/java/mltk/predictor/tree/RegressionTreeTest.java
================================================
package mltk.predictor.tree;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;

import org.junit.Assert;
import org.junit.Test;

import mltk.predictor.io.PredictorReader;
import mltk.util.MathUtils;

/** Unit test for {@link RegressionTree} serialization. */
public class RegressionTreeTest {

	/**
	 * Round-trips tree 1 through an in-memory buffer and checks every node of
	 * the parsed tree: root, left leaf, right interior node and its two leaves.
	 */
	@Test
	public void testIO() {
		RegressionTree tree = RegressionTreeTestHelper.getInstance().getTree1();

		try {
			ByteArrayOutputStream buffer = new ByteArrayOutputStream();
			PrintWriter writer = new PrintWriter(buffer);
			tree.write(writer);
			writer.close();

			ByteArrayInputStream bytes = new ByteArrayInputStream(buffer.toByteArray());
			BufferedReader reader = new BufferedReader(new InputStreamReader(bytes));
			RegressionTree parsed = PredictorReader.read(reader, RegressionTree.class);

			// Root
			Assert.assertTrue(parsed.root instanceof TreeInteriorNode);
			TreeInteriorNode root = (TreeInteriorNode) parsed.root;
			Assert.assertEquals(1, root.getSplitAttributeIndex());
			Assert.assertEquals(0.5, root.getSplitPoint(), MathUtils.EPSILON);
			Assert.assertTrue(root.left instanceof RegressionTreeLeaf);
			Assert.assertTrue(root.right instanceof TreeInteriorNode);

			// Left child of the root
			RegressionTreeLeaf leaf1 = (RegressionTreeLeaf) root.left;
			Assert.assertEquals(0.4, leaf1.getPrediction(), MathUtils.EPSILON);

			// Right child of the root
			TreeInteriorNode right = (TreeInteriorNode) root.right;
			Assert.assertEquals(2, right.getSplitAttributeIndex());
			Assert.assertEquals(-1.5, right.getSplitPoint(), MathUtils.EPSILON);
			Assert.assertTrue(right.left.isLeaf());
			Assert.assertTrue(right.right.isLeaf());

			RegressionTreeLeaf leaf2 = (RegressionTreeLeaf) right.left;
			Assert.assertEquals(0.5, leaf2.getPrediction(), MathUtils.EPSILON);
			RegressionTreeLeaf leaf3 = (RegressionTreeLeaf) right.right;
			Assert.assertEquals(0.6, leaf3.getPrediction(), MathUtils.EPSILON);
		} catch (Exception e) {
			Assert.fail("Should not see exception: " + e.getMessage());
		}
	}

}
================================================
FILE: src/test/java/mltk/predictor/tree/RegressionTreeTestHelper.java
================================================
package mltk.predictor.tree;

/**
 * Lazily created singleton that supplies two hand-built {@link RegressionTree}
 * fixtures to the tests.
 */
public class RegressionTreeTestHelper {

	private static RegressionTreeTestHelper instance = null;

	private RegressionTree tree1;
	private RegressionTree tree2;

	/**
	 * Returns the shared helper, creating it on first use.
	 *
	 * @return the singleton helper.
	 */
	public static RegressionTreeTestHelper getInstance() {
		if (instance == null) {
			instance = new RegressionTreeTestHelper();
		}
		return instance;
	}

	/**
	 * Returns the first fixture tree.
	 *
	 * @return the tree whose root has a leaf on the left and an interior node on the right.
	 */
	public RegressionTree getTree1() {
		return tree1;
	}

	/**
	 * Returns the second fixture tree.
	 *
	 * @return the tree whose root has an interior node on the left and a leaf on the right.
	 */
	public RegressionTree getTree2() {
		return tree2;
	}

	private RegressionTreeTestHelper() {
		buildTree1();
		buildTree2();
	}

	/** Root (1, 0.5): left is leaf 0.4, right is node (2, -1.5) with leaves 0.5 and 0.6. */
	private void buildTree1() {
		TreeInteriorNode right = new TreeInteriorNode(2, -1.5);
		right.left = new RegressionTreeLeaf(0.5);
		right.right = new RegressionTreeLeaf(0.6);

		TreeInteriorNode root = new TreeInteriorNode(1, 0.5);
		root.left = new RegressionTreeLeaf(0.4);
		root.right = right;

		tree1 = new RegressionTree(root);
	}

	/** Root (5, 0): left is node (0, -3.5) with leaves -0.4 and -0.5, right is leaf -0.6. */
	private void buildTree2() {
		TreeInteriorNode left = new TreeInteriorNode(0, -3.5);
		left.left = new RegressionTreeLeaf(-0.4);
		left.right = new RegressionTreeLeaf(-0.5);

		TreeInteriorNode root = new TreeInteriorNode(5, 0);
		root.left = left;
		root.right = new RegressionTreeLeaf(-0.6);

		tree2 = new RegressionTree(root);
	}

}
================================================
FILE: src/test/java/mltk/predictor/tree/ensemble/BoostedDTablesTest.java
================================================
package mltk.predictor.tree.ensemble;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;

import org.junit.Assert;
import org.junit.Test;

import mltk.predictor.tree.DecisionTable;
import mltk.predictor.tree.DecisionTableTestHelper;
import mltk.util.MathUtils;

/** Unit test for {@link BoostedDTables} serialization. */
public class BoostedDTablesTest {

	/**
	 * Writes a two-table list to an in-memory buffer, skips the first line,
	 * reads the rest back with {@code read} and compares attribute indices and
	 * splits of both tables.
	 */
	@Test
	public void testIO() {
		DecisionTable dt1 = DecisionTableTestHelper.getInstance().getTable1();
		DecisionTable dt2 = DecisionTableTestHelper.getInstance().getTable2();
		BoostedDTables original = new BoostedDTables();
		original.add(dt1);
		original.add(dt2);

		try {
			ByteArrayOutputStream buffer = new ByteArrayOutputStream();
			PrintWriter writer = new PrintWriter(buffer);
			original.write(writer);
			writer.close();

			ByteArrayInputStream bytes = new ByteArrayInputStream(buffer.toByteArray());
			BufferedReader reader = new BufferedReader(new InputStreamReader(bytes));
			// The first line is consumed here before the list reads the remainder.
			reader.readLine();
			BoostedDTables parsed = new BoostedDTables();
			parsed.read(reader);

			Assert.assertEquals(2, parsed.size());

			DecisionTable t1 = parsed.get(0);
			Assert.assertArrayEquals(dt1.getAttributeIndices(), t1.getAttributeIndices());
			Assert.assertArrayEquals(dt1.getSplits(), t1.getSplits(), MathUtils.EPSILON);

			DecisionTable t2 = parsed.get(1);
			Assert.assertArrayEquals(dt2.getAttributeIndices(), t2.getAttributeIndices());
			Assert.assertArrayEquals(dt2.getSplits(), t2.getSplits(), MathUtils.EPSILON);
		} catch (Exception e) {
			Assert.fail("Should not see exception: " + e.getMessage());
		}
	}

}
================================================
FILE: src/test/java/mltk/predictor/tree/ensemble/BoostedRTreesTest.java
================================================
package mltk.predictor.tree.ensemble;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;

import org.junit.Assert;
import org.junit.Test;

import mltk.predictor.tree.RegressionTree;
import mltk.predictor.tree.RegressionTreeLeaf;
import mltk.predictor.tree.RegressionTreeTestHelper;
import mltk.predictor.tree.TreeInteriorNode;
import mltk.util.MathUtils;

/** Unit test for {@link BoostedRTrees} serialization. */
public class BoostedRTreesTest {

	/**
	 * Writes a two-tree list to an in-memory buffer, skips the first line,
	 * reads the rest back with {@code read} and inspects the root of the first
	 * parsed tree.
	 */
	@Test
	public void testIO() {
		RegressionTree tree1 = RegressionTreeTestHelper.getInstance().getTree1();
		RegressionTree tree2 = RegressionTreeTestHelper.getInstance().getTree2();
		BoostedRTrees original = new BoostedRTrees();
		original.add(tree1);
		original.add(tree2);

		try {
			ByteArrayOutputStream buffer = new ByteArrayOutputStream();
			PrintWriter writer = new PrintWriter(buffer);
			original.write(writer);
			writer.close();

			ByteArrayInputStream bytes = new ByteArrayInputStream(buffer.toByteArray());
			BufferedReader reader = new BufferedReader(new InputStreamReader(bytes));
			// The first line is consumed here before the list reads the remainder.
			reader.readLine();
			BoostedRTrees parsed = new BoostedRTrees();
			parsed.read(reader);

			Assert.assertEquals(2, parsed.size());
			Assert.assertTrue(parsed.get(0) instanceof RegressionTree);
			Assert.assertTrue(parsed.get(1) instanceof RegressionTree);

			RegressionTree first = (RegressionTree) parsed.get(0);
			Assert.assertTrue(first.getRoot() instanceof TreeInteriorNode);
			TreeInteriorNode root = (TreeInteriorNode) first.getRoot();
			Assert.assertEquals(1, root.getSplitAttributeIndex());
			Assert.assertEquals(0.5, root.getSplitPoint(), MathUtils.EPSILON);
			Assert.assertTrue(root.getLeftChild() instanceof RegressionTreeLeaf);
			Assert.assertTrue(root.getRightChild() instanceof TreeInteriorNode);
		} catch (Exception e) {
			Assert.fail("Should not see exception: " + e.getMessage());
		}
	}

}
================================================
FILE: src/test/java/mltk/predictor/tree/ensemble/brt/BDTTest.java
================================================
package mltk.predictor.tree.ensemble.brt;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;

import org.junit.Assert;
import org.junit.Test;

import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.InstancesTestHelper;
import mltk.predictor.io.PredictorReader;
import mltk.predictor.tree.DecisionTable;
import mltk.predictor.tree.DecisionTableTestHelper;
import mltk.predictor.tree.ensemble.BoostedDTables;
import mltk.predictor.tree.ensemble.BoostedRTrees;
import mltk.util.MathUtils;

/** Unit tests for {@link BDT}. */
public class BDTTest {

	private BDT bdt;

	/**
	 * Builds the fixture: a one-slot {@link BDT} whose only table list is
	 * created from a {@link BoostedRTrees} holding the two helper tables.
	 */
	public BDTTest() {
		DecisionTable table1 = DecisionTableTestHelper.getInstance().getTable1();
		DecisionTable table2 = DecisionTableTestHelper.getInstance().getTable2();
		BoostedRTrees trees = new BoostedRTrees();
		trees.add(table1);
		trees.add(table2);

		bdt = new BDT(1);
		bdt.tables[0] = new BoostedDTables(trees);
	}

	/**
	 * Round-trips the fixture through an in-memory buffer and compares the
	 * attribute indices and splits of both parsed tables.
	 */
	@Test
	public void testIO() {
		DecisionTable dt1 = DecisionTableTestHelper.getInstance().getTable1();
		DecisionTable dt2 = DecisionTableTestHelper.getInstance().getTable2();

		try {
			ByteArrayOutputStream buffer = new ByteArrayOutputStream();
			PrintWriter writer = new PrintWriter(buffer);
			bdt.write(writer);
			writer.close();

			ByteArrayInputStream bytes = new ByteArrayInputStream(buffer.toByteArray());
			BufferedReader reader = new BufferedReader(new InputStreamReader(bytes));
			BDT parsed = PredictorReader.read(reader, BDT.class);

			BoostedDTables tables = parsed.getDecisionTreeList(0);
			Assert.assertEquals(2, tables.size());

			DecisionTable t1 = tables.get(0);
			Assert.assertArrayEquals(dt1.getAttributeIndices(), t1.getAttributeIndices());
			Assert.assertArrayEquals(dt1.getSplits(), t1.getSplits(), MathUtils.EPSILON);

			DecisionTable t2 = tables.get(1);
			Assert.assertArrayEquals(dt2.getAttributeIndices(), t2.getAttributeIndices());
			Assert.assertArrayEquals(dt2.getSplits(), t2.getSplits(), MathUtils.EPSILON);
		} catch (Exception e) {
			Assert.fail("Should not see exception: " + e.getMessage());
		}
	}

	/**
	 * A {@link BRT} holding the same two tables must give the same prediction
	 * as the fixture for every instance of the dense regression dataset.
	 */
	@Test
	public void testRegress() {
		DecisionTable table1 = DecisionTableTestHelper.getInstance().getTable1();
		DecisionTable table2 = DecisionTableTestHelper.getInstance().getTable2();
		BoostedRTrees trees = new BoostedRTrees();
		trees.add(table1);
		trees.add(table2);

		BRT brt = new BRT(1);
		brt.trees[0] = trees;

		Instances instances = InstancesTestHelper.getInstance().getDenseRegressionDataset();
		for (Instance instance : instances) {
			Assert.assertEquals(brt.regress(instance), bdt.regress(instance), MathUtils.EPSILON);
		}
	}

}
================================================
FILE: src/test/java/mltk/predictor/tree/ensemble/brt/BRTTest.java
================================================
package mltk.predictor.tree.ensemble.brt;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;

import org.junit.Assert;
import org.junit.Test;

import mltk.predictor.io.PredictorReader;
import mltk.predictor.tree.RegressionTree;
import mltk.predictor.tree.RegressionTreeLeaf;
import mltk.predictor.tree.RegressionTreeTestHelper;
import mltk.predictor.tree.TreeInteriorNode;
import mltk.predictor.tree.ensemble.BoostedRTrees;
import mltk.util.MathUtils;

/** Unit test for {@link BRT} serialization. */
public class BRTTest {

	/**
	 * Round-trips a one-slot {@link BRT} holding the two helper trees through
	 * an in-memory buffer and inspects the root of the first parsed tree.
	 */
	@Test
	public void testIO() {
		RegressionTree tree1 = RegressionTreeTestHelper.getInstance().getTree1();
		RegressionTree tree2 = RegressionTreeTestHelper.getInstance().getTree2();
		BoostedRTrees trees = new BoostedRTrees();
		trees.add(tree1);
		trees.add(tree2);

		BRT brt = new BRT(1);
		brt.trees[0] = trees;

		try {
			ByteArrayOutputStream buffer = new ByteArrayOutputStream();
			PrintWriter writer = new PrintWriter(buffer);
			brt.write(writer);
			writer.close();

			ByteArrayInputStream bytes = new ByteArrayInputStream(buffer.toByteArray());
			BufferedReader reader = new BufferedReader(new InputStreamReader(bytes));
			BRT parsed = PredictorReader.read(reader, BRT.class);

			BoostedRTrees parsedTrees = parsed.getRegressionTreeList(0);
			Assert.assertEquals(2, parsedTrees.size());
			Assert.assertTrue(parsedTrees.get(0) instanceof RegressionTree);
			Assert.assertTrue(parsedTrees.get(1) instanceof RegressionTree);

			RegressionTree first = (RegressionTree) parsedTrees.get(0);
			Assert.assertTrue(first.getRoot() instanceof TreeInteriorNode);
			TreeInteriorNode root = (TreeInteriorNode) first.getRoot();
			Assert.assertEquals(1, root.getSplitAttributeIndex());
			Assert.assertEquals(0.5, root.getSplitPoint(), MathUtils.EPSILON);
			Assert.assertTrue(root.getLeftChild() instanceof RegressionTreeLeaf);
			Assert.assertTrue(root.getRightChild() instanceof TreeInteriorNode);
		} catch (Exception e) {
			Assert.fail("Should not see exception: " + e.getMessage());
		}
	}

}
================================================
FILE: src/test/java/mltk/predictor/tree/ensemble/brt/BRTUtilsTest.java
================================================
package mltk.predictor.tree.ensemble.brt;

import org.junit.Assert;
import org.junit.Test;

import mltk.predictor.tree.RegressionTreeLearner;
import mltk.predictor.tree.TreeLearner;
import mltk.util.MathUtils;

/**
 * Unit tests for {@code BRTUtils.parseTreeLearner}, covering the {@code rt}
 * and {@code rrt} prefixes with the {@code l}, {@code d}, {@code a} and
 * {@code s} mode flags.
 */
public class BRTUtilsTest {

	/** {@code "rt:l:100"}: leaf-count-limited mode with at most 100 leaves. */
	@Test
	public void testParseRegressionTreeLearner1() {
		TreeLearner parsed = BRTUtils.parseTreeLearner("rt:l:100");
		Assert.assertTrue(parsed instanceof RegressionTreeLearner);

		RegressionTreeLearner learner = (RegressionTreeLearner) parsed;
		Assert.assertEquals(RegressionTreeLearner.Mode.NUM_LEAVES_LIMITED, learner.getConstructionMode());
		Assert.assertEquals(100, learner.getMaxNumLeaves());
	}

	/** {@code "rt:d:5"}: depth-limited mode with maximum depth 5. */
	@Test
	public void testParseRegressionTreeLearner2() {
		TreeLearner parsed = BRTUtils.parseTreeLearner("rt:d:5");
		Assert.assertTrue(parsed instanceof RegressionTreeLearner);

		RegressionTreeLearner learner = (RegressionTreeLearner) parsed;
		Assert.assertEquals(RegressionTreeLearner.Mode.DEPTH_LIMITED, learner.getConstructionMode());
		Assert.assertEquals(5, learner.getMaxDepth());
	}

	/** {@code "rt:a:0.01"}: alpha-limited mode with alpha 0.01. */
	@Test
	public void testParseRegressionTreeLearner3() {
		TreeLearner parsed = BRTUtils.parseTreeLearner("rt:a:0.01");
		Assert.assertTrue(parsed instanceof RegressionTreeLearner);

		RegressionTreeLearner learner = (RegressionTreeLearner) parsed;
		Assert.assertEquals(RegressionTreeLearner.Mode.ALPHA_LIMITED, learner.getConstructionMode());
		Assert.assertEquals(0.01, learner.getAlpha(), MathUtils.EPSILON);
	}

	/** {@code "rt:s:50"}: minimum-leaf-size-limited mode with size 50. */
	@Test
	public void testParseRegressionTreeLearner4() {
		TreeLearner parsed = BRTUtils.parseTreeLearner("rt:s:50");
		Assert.assertTrue(parsed instanceof RegressionTreeLearner);

		RegressionTreeLearner learner = (RegressionTreeLearner) parsed;
		Assert.assertEquals(RegressionTreeLearner.Mode.MIN_LEAF_SIZE_LIMITED, learner.getConstructionMode());
		Assert.assertEquals(50, learner.getMinLeafSize());
	}

	/** {@code "rrt:l:100"}: leaf-count-limited mode with at most 100 leaves. */
	@Test
	public void testParseRobustRegressionTreeLearner1() {
		TreeLearner parsed = BRTUtils.parseTreeLearner("rrt:l:100");
		Assert.assertTrue(parsed instanceof RegressionTreeLearner);

		RegressionTreeLearner learner = (RegressionTreeLearner) parsed;
		Assert.assertEquals(RegressionTreeLearner.Mode.NUM_LEAVES_LIMITED, learner.getConstructionMode());
		Assert.assertEquals(100, learner.getMaxNumLeaves());
	}

	/** {@code "rrt:d:5"}: depth-limited mode with maximum depth 5. */
	@Test
	public void testParseRobustRegressionTreeLearner2() {
		TreeLearner parsed = BRTUtils.parseTreeLearner("rrt:d:5");
		Assert.assertTrue(parsed instanceof RegressionTreeLearner);

		RegressionTreeLearner learner = (RegressionTreeLearner) parsed;
		Assert.assertEquals(RegressionTreeLearner.Mode.DEPTH_LIMITED, learner.getConstructionMode());
		Assert.assertEquals(5, learner.getMaxDepth());
	}

	/** {@code "rrt:a:0.01"}: alpha-limited mode with alpha 0.01. */
	@Test
	public void testParseRobustRegressionTreeLearner3() {
		TreeLearner parsed = BRTUtils.parseTreeLearner("rrt:a:0.01");
		Assert.assertTrue(parsed instanceof RegressionTreeLearner);

		RegressionTreeLearner learner = (RegressionTreeLearner) parsed;
		Assert.assertEquals(RegressionTreeLearner.Mode.ALPHA_LIMITED, learner.getConstructionMode());
		Assert.assertEquals(0.01, learner.getAlpha(), MathUtils.EPSILON);
	}

	/** {@code "rrt:s:50"}: minimum-leaf-size-limited mode with size 50. */
	@Test
	public void testParseRobustRegressionTreeLearner4() {
		TreeLearner parsed = BRTUtils.parseTreeLearner("rrt:s:50");
		Assert.assertTrue(parsed instanceof RegressionTreeLearner);

		RegressionTreeLearner learner = (RegressionTreeLearner) parsed;
		Assert.assertEquals(RegressionTreeLearner.Mode.MIN_LEAF_SIZE_LIMITED, learner.getConstructionMode());
		Assert.assertEquals(50, learner.getMinLeafSize());
	}

}
================================================
FILE: src/test/java/mltk/predictor/tree/ensemble/brt/LogitBoostLearnerTest.java
================================================
package mltk.predictor.tree.ensemble.brt;

import org.junit.Assert;
import org.junit.Test;

import mltk.core.Instances;
import mltk.core.InstancesTestHelper;
import mltk.predictor.evaluation.Evaluator;
import mltk.predictor.evaluation.MetricFactory;
import mltk.predictor.tree.TreeLearner;

/** Unit test for {@link LogitBoostLearner}. */
public class LogitBoostLearnerTest {

	/**
	 * Trains a binary classifier on the dense classification dataset and
	 * requires its AUC on that same dataset to exceed 0.5.
	 */
	@Test
	public void testLogitBoostLearner() {
		TreeLearner baseLearner = BRTUtils.parseTreeLearner("rrt:d:3");
		Instances instances = InstancesTestHelper.getInstance().getDenseClassificationDataset();

		LogitBoostLearner learner = new LogitBoostLearner();
		learner.setLearningRate(0.1);
		learner.setMetric(MetricFactory.getMetric("auc"));
		learner.setTreeLearner(baseLearner);

		BRT model = learner.buildBinaryClassifier(instances, 10);
		double auc = Evaluator.evalAreaUnderROC(model, instances);
		Assert.assertTrue(auc > 0.5);
	}

}
================================================
FILE: src/test/java/mltk/util/ArrayUtilsTest.java
================================================
package mltk.util;

import org.junit.Assert;
import org.junit.Test;

/** Unit tests for {@link ArrayUtils}. */
public class ArrayUtilsTest {

	/** Parses a bracketed, comma-separated list of doubles. */
	@Test
	public void testParseDoubleArray() {
		String text = "[1.1, 2.2, 3.3, 4.4]";
		double[] expected = {1.1, 2.2, 3.3, 4.4};
		Assert.assertArrayEquals(expected, ArrayUtils.parseDoubleArray(text), MathUtils.EPSILON);
	}

	/** Parses a bracketed, comma-separated list of ints. */
	@Test
	public void testParseIntArray() {
		String text = "[1, 2, 3, 4]";
		int[] expected = {1, 2, 3, 4};
		Assert.assertArrayEquals(expected, ArrayUtils.parseIntArray(text));
	}

	/** Exercises {@code isConstant} on int and double arrays over full and partial ranges. */
	@Test
	public void testIsConstant() {
		int[] allOnes = {1, 1, 1};
		int[] leadingTwo = {2, 1, 1};
		Assert.assertTrue(ArrayUtils.isConstant(allOnes, 0, allOnes.length, 1));
		Assert.assertFalse(ArrayUtils.isConstant(leadingTwo, 0, leadingTwo.length, 1));
		Assert.assertTrue(ArrayUtils.isConstant(leadingTwo, 1, leadingTwo.length, 1));

		double[] allTenths = {0.1, 0.1, 0.1, 0.1};
		double[] leadingFifth = {0.2, 0.1, 0.1, 0.1};
		Assert.assertTrue(ArrayUtils.isConstant(allTenths, 0, allTenths.length, 0.1));
		Assert.assertFalse(ArrayUtils.isConstant(leadingFifth, 0, leadingFifth.length, 0.1));
		Assert.assertTrue(ArrayUtils.isConstant(leadingFifth, 1, leadingFifth.length, 0.1));
	}

	/** The median of the seven values 0.1 .. 0.7 is 0.4. */
	@Test
	public void testGetMedian() {
		double[] values = {0.7, 0.4, 0.3, 0.2, 0.5, 0.6, 0.1};
		Assert.assertEquals(0.4, ArrayUtils.getMedian(values), MathUtils.EPSILON);
	}

}
================================================
FILE: src/test/java/mltk/util/MathUtilsTest.java
================================================
package mltk.util;

import org.junit.Assert;
import org.junit.Test;

/** Unit tests for {@link MathUtils}. */
public class MathUtilsTest {

	/** Nearly identical values compare equal; 0.0 and 1.0 do not. */
	@Test
	public void testEquals() {
		Assert.assertTrue(MathUtils.equals(0.1, 0.10000001));
		Assert.assertFalse(MathUtils.equals(0.0, 1.0));
	}

	/** {@code indicator} maps true to 1 and false to 0. */
	@Test
	public void testIndicator() {
		Assert.assertEquals(1, MathUtils.indicator(true));
		Assert.assertEquals(0, MathUtils.indicator(false));
	}

	/** With the flag set, 0.5 is better than 0; with it cleared, it is not. */
	@Test
	public void testIsFirstBetter() {
		Assert.assertTrue(MathUtils.isFirstBetter(0.5, 0, true));
		Assert.assertFalse(MathUtils.isFirstBetter(0.5, 0, false));
	}

	/** 1.0 is integral, 1.1 is not. */
	@Test
	public void testIsInteger() {
		Assert.assertTrue(MathUtils.isInteger(1.0));
		Assert.assertFalse(MathUtils.isInteger(1.1));
	}

	/** Half of EPSILON counts as zero; twice EPSILON does not. */
	@Test
	public void testIsZero() {
		Assert.assertTrue(MathUtils.isZero(MathUtils.EPSILON / 2));
		Assert.assertFalse(MathUtils.isZero(MathUtils.EPSILON * 2));
	}

	/** The sigmoid of 0 is 0.5. */
	@Test
	public void testSigmoid() {
		Assert.assertEquals(0.5, MathUtils.sigmoid(0), MathUtils.EPSILON);
	}

	/** {@code sign} yields 1, 0 and -1 for both floating-point and integer arguments. */
	@Test
	public void testSign() {
		// Floating-point arguments
		Assert.assertEquals(1, MathUtils.sign(0.5));
		Assert.assertEquals(0, MathUtils.sign(0.0));
		Assert.assertEquals(-1, MathUtils.sign(-0.5));

		// Integer arguments
		Assert.assertEquals(1, MathUtils.sign(2));
		Assert.assertEquals(0, MathUtils.sign(0));
		Assert.assertEquals(-1, MathUtils.sign(-2));
	}

}
================================================
FILE: src/test/java/mltk/util/OptimUtilsTest.java
================================================
package mltk.util;

import org.junit.Assert;
import org.junit.Test;

/** Unit tests for {@link OptimUtils}. */
public class OptimUtilsTest {

	/** {@code getProbability(0)} is 0.5. */
	@Test
	public void testGetProbability() {
		Assert.assertEquals(0.5, OptimUtils.getProbability(0), MathUtils.EPSILON);
	}

	/** {@code getResidual(1.0, 0)} is -1.0. */
	@Test
	public void testGetResidual() {
		Assert.assertEquals(-1.0, OptimUtils.getResidual(1.0, 0), MathUtils.EPSILON);
	}

	/** At prediction 0 the pseudo residual is 0.5 for class 1 and -0.5 for class 0. */
	@Test
	public void testGetPseudoResidual() {
		Assert.assertEquals(0.5, OptimUtils.getPseudoResidual(0, 1), MathUtils.EPSILON);
		Assert.assertEquals(-0.5, OptimUtils.getPseudoResidual(0, 0), MathUtils.EPSILON);
	}

	/** Checks the logistic loss at predictions 0 and 5 for classes 1 and -1. */
	@Test
	public void testComputeLogisticLoss() {
		Assert.assertEquals(0.693147181, OptimUtils.computeLogisticLoss(0, 1), MathUtils.EPSILON);
		Assert.assertEquals(0.693147181, OptimUtils.computeLogisticLoss(0, -1), MathUtils.EPSILON);
		Assert.assertEquals(0.006715348, OptimUtils.computeLogisticLoss(5, 1), MathUtils.EPSILON);
		Assert.assertEquals(5.006715348, OptimUtils.computeLogisticLoss(5, -1), MathUtils.EPSILON);
	}

	/** With tolerance 1e-6, a 1e-9 difference counts as converged; a 0.05 difference does not. */
	@Test
	public void testIsConverged() {
		Assert.assertTrue(OptimUtils.isConverged(0.100000001, 0.1, 1e-6));
		Assert.assertFalse(OptimUtils.isConverged(0.15, 0.1, 1e-6));
	}

}
================================================
FILE: src/test/java/mltk/util/StatUtilsTest.java
================================================
package mltk.util;

import org.junit.Assert;
import org.junit.Test;

/** Unit tests for {@link StatUtils} on one int array and one double array. */
public class StatUtilsTest {

	private int[] ints = {1, 4, 3, 2};
	private double[] doubles = {-1.2, 1.2, -5.3, 5.3};

	/** Maximum values are 4 and 5.3. */
	@Test
	public void testMax() {
		Assert.assertEquals(4, StatUtils.max(ints));
		Assert.assertEquals(5.3, StatUtils.max(doubles), MathUtils.EPSILON);
	}

	/** The maxima sit at index 1 (ints) and index 3 (doubles). */
	@Test
	public void testIndexOfMax() {
		Assert.assertEquals(1, StatUtils.indexOfMax(ints));
		Assert.assertEquals(3, StatUtils.indexOfMax(doubles));
	}

	/** Minimum values are 1 and -5.3. */
	@Test
	public void testMin() {
		Assert.assertEquals(1, StatUtils.min(ints));
		Assert.assertEquals(-5.3, StatUtils.min(doubles), MathUtils.EPSILON);
	}

	/** The minima sit at index 0 (ints) and index 2 (doubles). */
	@Test
	public void testIndexOfMin() {
		Assert.assertEquals(0, StatUtils.indexOfMin(ints));
		Assert.assertEquals(2, StatUtils.indexOfMin(doubles));
	}

	/** The double values cancel out to a sum of 0. */
	@Test
	public void testSum() {
		Assert.assertEquals(0, StatUtils.sum(doubles), MathUtils.EPSILON);
	}

	/** Sum of squares over the whole array and over the range [0, 1). */
	@Test
	public void testSumSq() {
		Assert.assertEquals(59.06, StatUtils.sumSq(doubles), MathUtils.EPSILON);
		Assert.assertEquals(doubles[0] * doubles[0], StatUtils.sumSq(doubles, 0, 1), MathUtils.EPSILON);
	}

	/** The mean of the double values is 0. */
	@Test
	public void testMean() {
		Assert.assertEquals(0, StatUtils.mean(doubles), MathUtils.EPSILON);
	}

	/** The expected variance equals 59.06 / 3. */
	@Test
	public void testVariance() {
		Assert.assertEquals(19.686666667, StatUtils.variance(doubles), MathUtils.EPSILON);
	}

	/** The standard deviation is the square root of the expected variance. */
	@Test
	public void testStd() {
		Assert.assertEquals(Math.sqrt(19.686666667), StatUtils.sd(doubles), MathUtils.EPSILON);
	}

	/** The root mean square is sqrt(59.06 / length). */
	@Test
	public void testRms() {
		Assert.assertEquals(Math.sqrt(59.06 / doubles.length), StatUtils.rms(doubles), MathUtils.EPSILON);
	}

}
================================================
FILE: src/test/java/mltk/util/VectorUtilsTest.java
================================================
package mltk.util;

import org.junit.Assert;
import org.junit.Test;

/** Unit tests for {@link VectorUtils}. */
public class VectorUtilsTest {

	/** Adding 1 in place turns {1,2,3,4} into {2,3,4,5}. */
	@Test
	public void testAdd() {
		double[] vector = {1, 2, 3, 4};
		double[] expected = {2, 3, 4, 5};
		VectorUtils.add(vector, 1);
		Assert.assertArrayEquals(expected, vector, MathUtils.EPSILON);
	}

	/** Subtracting 1 in place turns {2,3,4,5} into {1,2,3,4}. */
	@Test
	public void testSubtract() {
		double[] expected = {1, 2, 3, 4};
		double[] vector = {2, 3, 4, 5};
		VectorUtils.subtract(vector, 1);
		Assert.assertArrayEquals(expected, vector, MathUtils.EPSILON);
	}

	/** Multiplying by 2 in place turns {1,2,3,4} into {2,4,6,8}. */
	@Test
	public void testMultiply() {
		double[] vector = {1, 2, 3, 4};
		double[] expected = {2, 4, 6, 8};
		VectorUtils.multiply(vector, 2);
		Assert.assertArrayEquals(expected, vector, MathUtils.EPSILON);
	}

	/** Dividing by 2 in place turns {2,4,6,8} into {1,2,3,4}. */
	@Test
	public void testDivide() {
		double[] expected = {1, 2, 3, 4};
		double[] vector = {2, 4, 6, 8};
		VectorUtils.divide(vector, 2);
		Assert.assertArrayEquals(expected, vector, MathUtils.EPSILON);
	}

	/** The L2 norm of {1,2,3,4} is sqrt(30). */
	@Test
	public void testL2norm() {
		double[] vector = {1, 2, 3, 4};
		Assert.assertEquals(5.477225575, VectorUtils.l2norm(vector), MathUtils.EPSILON);
	}

	/** The L1 norm of {1,-2,3,-4} is 10. */
	@Test
	public void testL1norm() {
		double[] vector = {1, -2, 3, -4};
		Assert.assertEquals(10, VectorUtils.l1norm(vector), MathUtils.EPSILON);
	}

	/** The dot product of {1,2,3,4} and {0,-1,0,1} is 2. */
	@Test
	public void testDotProduct() {
		double[] left = {1, 2, 3, 4};
		double[] right = {0, -1, 0, 1};
		Assert.assertEquals(2, VectorUtils.dotProduct(left, right), MathUtils.EPSILON);
	}

	/** A positive multiple correlates at 1, a negative multiple at -1. */
	@Test
	public void testCorrelation() {
		double[] base = {1, 2, 3, 4};
		double[] doubled = {2, 4, 6, 8};
		double[] negated = {-2, -4, -6, -8};
		Assert.assertEquals(1, VectorUtils.correlation(base, doubled), MathUtils.EPSILON);
		Assert.assertEquals(-1, VectorUtils.correlation(base, negated), MathUtils.EPSILON);
	}

}