Repository: etsy/Conjecture Branch: master Commit: a32d61966b12 Files: 117 Total size: 449.7 KB Directory structure: gitextract__actrptd/ ├── .gitignore ├── .travis.yml ├── LICENSE.md ├── README.md ├── bin/ │ ├── demo.sh │ ├── model_diff.py │ ├── model_param.py │ └── prediction_inspection.py ├── build.sbt ├── clients/ │ └── phplib/ │ └── Conjecture/ │ ├── BinaryClassifier.php │ ├── Config.php │ ├── ConjectureException.php │ ├── Finder.php │ ├── Instance.php │ ├── MulticlassClassifier.php │ ├── MulticlassLogisticRegressionClassifier.php │ ├── MulticlassOneVsAllClassifier.php │ ├── Text.php │ ├── TextSequence.php │ └── Vector.php ├── data/ │ └── iris.tsv ├── project/ │ ├── build.properties │ └── plugins.sbt ├── sbt └── src/ ├── main/ │ ├── java/ │ │ └── com/ │ │ └── etsy/ │ │ └── conjecture/ │ │ ├── GenericPair.java │ │ ├── PrimitivePair.java │ │ ├── Utilities.java │ │ ├── data/ │ │ │ ├── AbstractInstance.java │ │ │ ├── BinaryLabel.java │ │ │ ├── BinaryLabeledInstance.java │ │ │ ├── ByteArrayDoubleHashMap.java │ │ │ ├── ClusterLabel.java │ │ │ ├── ClusterPrediction.java │ │ │ ├── Instance.java │ │ │ ├── InstanceFactory.java │ │ │ ├── InstanceInterface.java │ │ │ ├── Label.java │ │ │ ├── LabeledInstance.java │ │ │ ├── LazyVector.java │ │ │ ├── MulticlassLabel.java │ │ │ ├── MulticlassLabeledInstance.java │ │ │ ├── MulticlassPrediction.java │ │ │ ├── RealValueLabeledInstance.java │ │ │ ├── RealValuedLabel.java │ │ │ ├── Recommendation.java │ │ │ └── StringKeyedVector.java │ │ ├── evaluation/ │ │ │ ├── BinaryModelEvaluation.java │ │ │ ├── ConfusionMatrix.java │ │ │ ├── EvaluationAggregator.java │ │ │ ├── ModelEvaluation.java │ │ │ ├── MulticlassConfusionMatrix.java │ │ │ ├── MulticlassModelEvaluation.java │ │ │ ├── MulticlassReceiverOperatingCharacteristic.java │ │ │ ├── ReceiverOperatingCharacteristic.java │ │ │ └── RegressionModelEvaluation.java │ │ ├── model/ │ │ │ ├── AdagradOptimizer.java │ │ │ ├── ClusteringModel.java │ │ │ ├── ControlOptimizer.java │ │ │ ├── 
Decomposable.java │ │ │ ├── ElasticNetOptimizer.java │ │ │ ├── FTRLOptimizer.java │ │ │ ├── Hinge.java │ │ │ ├── KMeans.java │ │ │ ├── LeastSquaresRegressionModel.java │ │ │ ├── LogisticRegression.java │ │ │ ├── MIRA.java │ │ │ ├── MIRAOptimizer.java │ │ │ ├── Model.java │ │ │ ├── PassiveAggressiveOptimizer.java │ │ │ ├── SGDOptimizer.java │ │ │ ├── UpdateableLinearModel.java │ │ │ ├── UpdateableModel.java │ │ │ └── UpdateableMulticlassLinearModel.java │ │ └── topics/ │ │ └── lda/ │ │ ├── LDADenseTopics.java │ │ ├── LDADict.java │ │ ├── LDADoc.java │ │ ├── LDAPartialSparseTopics.java │ │ ├── LDAPartialTopics.java │ │ ├── LDARandomTopics.java │ │ ├── LDASparseTopics.java │ │ ├── LDATopics.java │ │ └── LDAUtils.java │ └── scala/ │ └── com/ │ └── etsy/ │ ├── conjecture/ │ │ ├── VWReader.scala │ │ ├── demo/ │ │ │ ├── DemoLinearHyperparameterSearch.scala │ │ │ ├── IrisDataToMulticlassLabeledInstances.scala │ │ │ └── LearnMulticlassClassifier.scala │ │ ├── scalding/ │ │ │ ├── ALSJob.scala │ │ │ ├── FastKNN.scala │ │ │ ├── LSH.scala │ │ │ ├── NNMF.scala │ │ │ ├── SVD.scala │ │ │ ├── evaluate/ │ │ │ │ ├── GenericCrossValidator.scala │ │ │ │ └── GenericEvaluator.scala │ │ │ ├── factorize/ │ │ │ │ └── FactorizationTools.scala │ │ │ ├── train/ │ │ │ │ ├── AbstractModelTrainer.scala │ │ │ │ ├── BinaryModelTrainer.scala │ │ │ │ ├── ClusteringModelTrainer.scala │ │ │ │ ├── LargeModelTrainer.scala │ │ │ │ ├── ModelTrainerStrategy.scala │ │ │ │ ├── MulticlassModelTrainer.scala │ │ │ │ ├── RegressionModelTrainer.scala │ │ │ │ └── SmallModelTrainer.scala │ │ │ └── util/ │ │ │ ├── BaseGridSearcher.scala │ │ │ ├── DynamicOptions.scala │ │ │ └── HyperparameterSearcher.scala │ │ └── text/ │ │ ├── FeatureHelper.scala │ │ ├── Text.scala │ │ └── TextSequence.scala │ └── scalding/ │ └── jobs/ │ └── conjecture/ │ ├── AdHocClassifier.scala │ ├── AdHocClusterer.scala │ ├── AdHocMulticlassClassifier.scala │ ├── AdHocPredictor.scala │ └── NNMFTest.scala └── test/ └── java/ └── com/ └── etsy/ └── 
conjecture/ ├── data/ │ ├── LazyVectorTest.java │ └── StringKeyedVectorTest.java ├── evaluation/ │ └── TestReceiverOperatingCharacteristic.java └── model/ └── UpdateableLinearModelTest.java ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.class *.log *.swp *.swo # sbt specific dist/* target/ lib_managed/ src_managed/ project/boot/ project/plugins/project/ # Scala-IDE specific .scala_dependencies #java *.class # Package Files # *.jar *.war *.ear *~ *\# .history .idea ================================================ FILE: .travis.yml ================================================ sudo: false language: scala script: - sbt +test ================================================ FILE: LICENSE.md ================================================ The MIT License =============== Copyright (c) 2009 Anton Grigoryev Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # Conjecture [![Build Status](https://travis-ci.org/etsy/Conjecture.svg?branch=master)](https://travis-ci.org/etsy/Conjecture) Conjecture is a framework for building machine learning models in Hadoop using the Scalding DSL. The goal of this project is to enable the development of statistical models as viable components in a wide range of product settings. Applications include classification and categorization, recommender systems, ranking, filtering, and regression (predicting real-valued numbers). Conjecture has been designed with a primary emphasis on flexibility and can handle a wide variety of inputs. Integration with Hadoop and Scalding enables seamless handling of extremely large data volumes, as well as integration with established ETL processes. Predicted labels can either be consumed directly by the web stack using the dataset loader, or models can be deployed and consumed by live web code. Currently, binary classification (assigning one of two possible labels to input data points) is the most mature component of the Conjecture package. # Tutorial There are a few stages involved in training a machine learning model using Conjecture. ## Create Training Data We represent the training data as "feature vectors", which are simply mappings of feature names to real values. Concretely, we represent them as a Java map of strings to doubles (we also provide a class, StringKeyedVector, with convenience methods for feature vector construction). We also need the true label of each instance, which we represent as 0 or 1 (the mapping of these binary labels to, e.g., "male" and "female", is up to the user).
We construct BinaryLabeledInstances, which are just wrappers for a feature vector and a label. val bl = new BinaryLabeledInstance(0.0) bl.addTerm("bias", 1.0) bl.addTerm("some_feature", 0.5) ## Training a Classifier Classifiers are trained by presenting the labeled instances to them. We implement several kinds of linear classifiers, among them: * Logistic regression, * Perceptron, * MIRA (a large-margin perceptron model), * Passive-aggressive. These models all have several options, such as learning rate, regularization parameters, and so on. We supply reasonable defaults for these parameters, although they can be changed readily. To train a linear model, simply call the update function with the labeled instance: val p = new LogisticRegression() p.update(bl) To make this procedure tractable for large datasets, we provide Scalding wrappers for training. These operate by training several small models on the mappers, then aggregating them into a final complete model on the reducers. This wrapper is called like so: new BinaryModelTrainer(args) .train(instances, 'instance, 'model) .write(SequenceFile("model")) .map('model -> 'model){ x : UpdateableBinaryModel => new com.google.gson.Gson().toJson(x) } .write(Tsv("model_json")) This code segment trains a model using a pipe called instances, which has a field called instance containing the BinaryLabeledInstance objects. It produces a pipe with a single field containing the completed model, which can then be written to disk. This class takes the command-line args object from Scalding so that you can set some options on the command line. Some useful options are: | Argument | Possible values | Default | Meaning | |-------------------------------------|-----------------------------------------------|--------------------|--------------------------------------------------| | --model | mira, logistic_regression, passive_aggressive | passive_aggressive | The type of model to use.
| | --iters | 1, 2, 3... | 1 | The number of iterations of training to perform. | | --zero_class_prob, --one_class_prob | [0, 1] | 1 | | For the full list of command-line options, see the BinaryModelTrainer class. ## Evaluating a Classifier It is important to get a sense of the performance you can expect from your classifier on unseen data. To estimate this, we recommend using cross-validation. In essence, your input set of instances is split into training and testing portions in several different ways; a classifier is trained on each training portion and then evaluated on the corresponding testing portion (against the true labels, which are present). This is all wrapped up in a class called BinaryCrossValidator, which is used like so: new BinaryCrossValidator(args, 5) .crossValidate(instances, 'instance) .write(Tsv("model_xval")) This class also takes the command-line arguments, which it passes to a model trainer for each fold. This lets you specify options for the cross-validated models on the command line. The output contains statistics about the performance of the model, as well as the confusion matrices for each fold. A script is included which cross-validates a logistic regression model on the iris dataset. ================================================ FILE: bin/demo.sh ================================================ #!/bin/bash # - make monolithic conjecture jar. sbt clean assembly # - make the instances. java -cp target/conjecture-assembly-*.jar com.twitter.scalding.Tool com.etsy.conjecture.demo.IrisDataToMulticlassLabeledInstances --input_file data/iris.tsv --output_file iris_model/instances --local # - construct the classifier.
java -cp target/conjecture-assembly-*.jar com.twitter.scalding.Tool com.etsy.conjecture.demo.LearnMulticlassClassifier --input iris_model/instances --output iris_model --class_names Iris-versicolor,Iris-virginica,Iris-setosa --iters 5 --folds 3 --local ================================================ FILE: bin/model_diff.py ================================================ import json import sys import math if __name__ == '__main__': if len(sys.argv) != 3: sys.exit("Usage: python " + sys.argv[0] + " [model file] [model file]") a = json.load(open(sys.argv[1]))['param']['vector'] b = json.load(open(sys.argv[2]))['param']['vector'] features = set(a.keys()) | set(b.keys()) diff = [] for f in features: dv = a.get(f, 0.0) - b.get(f, 0.0) if math.fabs(dv) > 0.01: diff.append((f, dv, a.get(f), b.get(f))) diff.sort( key = lambda tup: -math.fabs(tup[1])) for t in diff: print t[0] + "\t" + str(t[2]) + "\t" + str(t[3]) + "\t(" + str(t[1]) + ")" ================================================ FILE: bin/model_param.py ================================================ import json import sys import math if __name__ == '__main__': if len(sys.argv) != 2: sys.exit("Usage: python " + sys.argv[0] + " [model file]") vec = json.load(open(sys.argv[1]))['param']['vector'].items() vec.sort(key = lambda tup: -math.fabs(tup[1])) for v in vec: print v[0] + "\t" + str(v[1]) ================================================ FILE: bin/prediction_inspection.py ================================================ import json import sys from optparse import OptionParser from math import floor colors = ["FF0000", "FF1000", "FF2000", "FF3000", "FF4000", "FF5000", "FF6000", "FF7000", "FF8000", "FF9000", "FFA000", "FFB000", "FFC000", "FFD000", "FFE000", "FFF000", "FFFF00", "F0FF00", "E0FF00", "D0FF00", "C0FF00", "B0FF00", "A0FF00", "90FF00", "80FF00", "70FF00", "60FF00", "50FF00", "40FF00", "30FF00", "20FF00", "10FF00"] bins = len(colors) parser = OptionParser(usage="""builds a simple web page providing 
introspection on predictions made by conjecture models. Depends on the supporting data provided in the instance itself, currently only supporting binary classification problems Usage: %prog [options] """) parser.add_option('-o', '--out', dest='out', default=False, action='store', help="[optional] destination of the generated html. Defaults to standard out") parser.add_option('-f', '--file', dest='file', default=False, action='store', help="[optional] file storing input predictions and instances. Defaults to standard in") parser.add_option('-l', '--label', dest='label', default=False, action='store', help="[optional] only keep examples with this label") parser.add_option('-L', '--limit', dest='limit', default=1000, action='store', help="maximum number of prediction examples to display. Default: 1000") (options, args) = parser.parse_args() output = open(options.out, 'w') if (options.out) else sys.stdout input = open(options.file, 'r') if(options.file) else sys.stdin limit = int(options.limit) output.write("") ct = 0 for line in input: parts = line.strip().split("\t") content = json.loads(parts[0]) label = int(content['label']['value']) pred = float(parts[2]) if (options.label and str(label) != options.label): continue error = min(1.0, abs(pred-label)) bin = bins - int(floor(error*bins)) - 1 color = "#" + colors[bin] out = "" support = json.loads(content['supporting_data']) for key in support.keys(): out = out + "" + key + "
" + support[key] + "
" if (len(out) < 10000 and ct < limit): try: output.write("
"); output.write("%d (%f)
" %( label, pred)) output.write(out) output.write("

") ct = ct + 1 except: pass if (ct >= limit): break output.write(""); output.flush() output.close() ================================================ FILE: build.sbt ================================================ import sbt._ name := "conjecture" version := "0.3.1-SNAPSHOT" organization := "com.etsy" scalaVersion := "2.11.11" crossScalaVersions := Seq("2.11.11", "2.12.4") scalacOptions ++= Seq("-unchecked", "-deprecation") //Because some of our (legal!) java code confuses scaladoc, we must skip it for 2.12 //See: https://github.com/scala/bug/issues/10723 scalacOptions in (Compile, doc) += {if(scalaBinaryVersion.value == "2.12") "-no-java-comments" else ""} javacOptions ++= Seq("-Xlint:none", "-source", "1.7", "-target", "1.7") compileOrder := CompileOrder.JavaThenScala resolvers ++= { Seq( "Concurrent Maven Repo" at "http://conjars.org/repo" ) } libraryDependencies ++= Seq( "cascading" % "cascading-core" % "2.6.1", "cascading" % "cascading-local" % "2.6.1" exclude("com.google.guava", "guava"), "cascading" % "cascading-hadoop" % "2.6.1", "com.google.code.gson" % "gson" % "2.2.2", "com.twitter" %% "algebird-core" % "0.13.0" excludeAll ExclusionRule(organization="org.scala-lang", name="scala-library"), "com.twitter" %% "scalding-core" % "0.17.4" excludeAll ExclusionRule(organization="org.scala-lang", name="scala-library"), "commons-lang" % "commons-lang" % "2.4", "com.joestelmach" % "natty" % "0.7", "io.spray" %% "spray-json" % "1.3.2" excludeAll ExclusionRule(organization="org.scala-lang", name="scala-library"), "com.google.guava" % "guava" % "13.0.1", "org.apache.commons" % "commons-math3" % "3.2", "org.apache.hadoop" % "hadoop-common" % "2.5.0" excludeAll( ExclusionRule(organization="commons-daemon", name="commons-daemon"), ExclusionRule(organization="com.google.guava", name="guava") ), "org.apache.hadoop" % "hadoop-hdfs" % "2.5.0" excludeAll( ExclusionRule(organization="commons-daemon", name="commons-daemon"), ExclusionRule(organization="com.google.guava", 
name="guava") ), "org.scala-lang" % "scala-reflect" % scalaVersion.value, "net.sf.trove4j" % "trove4j" % "3.0.3", "com.novocode" % "junit-interface" % "0.10" % "test" ) parallelExecution in Test := false publishArtifact in Test := false xerial.sbt.Sonatype.sonatypeSettings publishTo := { if (System.getProperty("release") != null) { publishTo.value } else { val v = version.value val archivaURL = "http://ivy.etsycorp.com/repository" if (v.trim.endsWith("SNAPSHOT")) { Some("publish-snapshots" at (archivaURL + "/snapshots")) } else { Some("publish-releases" at (archivaURL + "/internal")) } } } publishMavenStyle := true overridePublishBothSettings pomIncludeRepository := { x => false } pomExtra := https://github.com/etsy/Conjecture MIT License http://opensource.org/licenses/MIT repo git@github.com:etsy/Conjecture.git scm:git:git@github.com:etsy/Conjecture.git jattenberg Josh Attenberg github.com/jattenberg rjhall Rob Hall github.com/rjhall pomIncludeRepository := { _ => false } // Uncomment if you don't want to run all the tests before building assembly // test in assembly := {} // Janino includes a broken signature, and is not needed: assemblyExcludedJars in assembly <<= (fullClasspath in assembly) map { cp => val excludes = Set("jsp-api-2.1-6.1.14.jar", "jsp-2.1-6.1.14.jar", "jasper-compiler-5.5.12.jar", "janino-2.5.16.jar") cp filter { jar => excludes(jar.data.getName)} } // Some of these files have duplicates, let's ignore: assemblyMergeStrategy in assembly <<= (mergeStrategy in assembly) { (old) => { case s if s.endsWith(".class") => MergeStrategy.last case s if s.endsWith("project.clj") => MergeStrategy.concat case s if s.endsWith(".html") => MergeStrategy.last case s if s.contains("servlet") => MergeStrategy.last case x => old(x) } } ================================================ FILE: clients/phplib/Conjecture/BinaryClassifier.php ================================================ param = $param_vec; } public function dot($instance_vec) { return 
$this->param->dot($instance_vec); } public function predict($instance_vec) { $dot = $this->dot($instance_vec); $exd = exp($dot); return $exd / (1.0 + $exd); } public function getParams() { return $this->param->getParams(); } public function explain($instance_vec, $n = 10) { $keys = array_intersect_key($this->param->getParams(), $instance_vec->getParams()); $keys = array_map('abs', $keys); arsort($keys); $res = array_slice($keys, 0, (count($keys) < $n ? count($keys) : $n)); foreach ($res as $k => $v) { $res[$k] = "$k(" . round($this->param->getParam($k), 2) . ")"; } return implode(", ", $res); } } ================================================ FILE: clients/phplib/Conjecture/Config.php ================================================ config = $config; } /** * Loads a model local to a user's vm. */ public function getLocalModel($local_file_path) { $model = json_decode($this->parseFile($local_file_path)); $cv = new Conjecture_Vector($model->param->vector); $binary_classifier = new Conjecture_BinaryClassifier($cv); return $binary_classifier; } /** * Decode model json at a given filepath. */ private function parseFile($fp) { if (filesize($fp) > $this->config->getMaxFileSize()) { throw new Conjecture_ConjectureException("model too big: " . $fp . " is " . filesize($fp) . "bytes"); } $res = file($fp); if ($res) { $res = implode("", $res); $res = stripslashes($res); return $res; } else { throw new Conjecture_ConjectureException("model file not found: $fp"); } } private function getLatestModelJsonForProblem($file_name) { if ($this->config->useDummyConjectureModel()) { return self::getDummyModel(); } $fp = $this->config->getConjectureModelPath() . "/" . 
$file_name; return $this->parseFile($fp); } public function getLatestModelForProblem($file_name) { $json = $this->getLatestModelJsonForProblem($file_name); return json_decode($json); } public function getLatestBinaryClassificationVectorForProblem($file_name) { $model = $this->getLatestModelForProblem($file_name); return new Conjecture_Vector($model->param->vector); } public function getLatestBinaryClassifierForProblem($file_name) { return new Conjecture_BinaryClassifier($this->getLatestBinaryClassificationVectorForProblem($file_name)); } public function getOneVsAllClassifier($file_name) { $model_array = $this->getLatestModelForProblem($file_name); foreach ($model_array as $cat => $params) { $category_params[$cat] = new Conjecture_BinaryClassifier(new Conjecture_Vector($params)); } return new Conjecture_MulticlassOneVsAllClassifier($category_params); } public function getMulticlassClassifier($file_name) { $model_array = $this->getLatestModelForProblem($file_name); $model_type = $model_array->modelType; $category_params = []; foreach ($model_array->param as $cat => $category_model) { $category_params[$cat] = new Conjecture_Vector($category_model->vector); } switch ($model_type) { case "multiclass_logistic_regression": return new Conjecture_MulticlassLogisticRegressionClassifier($category_params); default: return new Conjecture_MulticlassClassifier($category_params); } } static function build(Conjecture_Config $config) { return new Conjecture_Finder($config); } /** * Creates and returns a JSON dummy model with no vectors * used for development settings where "real" JSON models * may not be present */ private static function getDummyModel() { $dummy_model = array("param" => array( "vector" => array(), "modelType" => "dummy", "regularizationWeights" => array(), "epoch" => 1, "period" => 1, "truncationUpdate" => 0, "truncationThreshold" => 0, "initialLearningRate" => .1, "useExponentialLearningRate" => false, "exponentialLearningRate" => 1.0, "examplesPerEpoch" => 1, )); 
return json_encode($dummy_model); } } ================================================ FILE: clients/phplib/Conjecture/Instance.php ================================================ id; } public function setId($id) { $this->id = $id; return $this; } public function put($key, $value = 1.0) { $this->vector[$key] = $value; } public function update($key, $value = 1.0) { if (array_key_exists($key, $this->vector)) { $this->vector[$key] = $this->vector[$key] + $value; } else { $this->vector[$key] = $value; } return $this; } //some methods to mirror java maps that this class mirrors public function putAll(array $vector) { foreach ($vector as $key => $value) { $this->put($key, $value); } } public function containsKey($key) { return array_key_exists($key, $this->vector); } public function containsValue($key) { return in_array($key, $this->vector); } public function keySet() { return array_keys($this->vector); } public function values() { return array_values($this->vector); } public function size() { return count($this->vector); } public function isEmpty() { return empty($this->vector); } public function remove($key) { unset($this->vector[$key]); } public function toString() { return json_encode($this->vector); } public function addTerm($term, $featureWeight = 1.0, $namespace = "") { $key = $namespace == "" ? $term : $namespace . self::$NAMESPACE_SEP . 
$term; $this->update($key, $featureWeight); return $this; } public function addTerms(array $terms, $featureWeight = 1.0, $namespace = "") { foreach ($terms as $term) { $this->addTerm($term, $featureWeight, $namespace); } return $this; } public function addNumericArray(array $numberValues, $namespace = "") { for ($i = 0; $i < count($numberValues); $i++ ) { $this->addTerm((string)$i, $numberValues[$i], $namespace); } return $this; } } ================================================ FILE: clients/phplib/Conjecture/MulticlassClassifier.php ================================================ param = $param; } public function predict($instance_vec) { $category_results = []; $total = 0; foreach ($this->param as $category => $classifier) { $prediction = $classifier->dot($instance_vec); $category_results[$category] = $prediction; $total += $prediction; } return array_map( function($prob) use ($total) { return $prob / $total; }, $category_results); } public function getParams() { return $this->param; } public function explain($instance_vec, $n = 10) { $explains = []; foreach ($this->param as $category => $category_model) { $explains[$category] = $this->categoryExplain($instance_vec, $category_model, $n); } return implode(", ", $explains); } private function categoryExplain($instance_vec, $category_model, $n = 10) { $keys = array_intersect_key($category_model->getParams(), $instance_vec->getParams()); $keys = array_map('abs', $keys); arsort($keys); $res = array_slice($keys, 0, (count($keys) < $n ? count($keys) : $n)); foreach ($res as $k => $v) { $res[$k] = "$k(" . round($category_model->getParams($k), 2) . 
")"; } return implode(", ", $res); } } ================================================ FILE: clients/phplib/Conjecture/MulticlassLogisticRegressionClassifier.php ================================================ param as $category => $classifier) { $prediction = exp($classifier->dot($instance_vec)); $category_results[$category] = $prediction; $total += $prediction; } return array_map( function($prob) use ($total) { return $prob / $total; }, $category_results); } } ================================================ FILE: clients/phplib/Conjecture/MulticlassOneVsAllClassifier.php ================================================ param = $param; } public function predict($instance_vec) { $category_results = []; $total = 0; foreach ($this->param as $category => $classifier) { $prediction = $classifier->predict($instance_vec); $category_results[$category] = $prediction; $total += $prediction; } return array_map( function($prob) use ($total) { return $prob / $total; }, $category_results); } public function getParams() { $out_params = []; foreach ($this->param as $category => $classifier) { $out_params[$category] = $classifier->getParams(); } return $out_params; } public function explain($instance_vec, $n = 10) { $explains = []; foreach ($this->param as $category => $classifier) { $explains[$category] = $classifier->explain($instance_vec, $n); } return implode(", ", $explains); } } ================================================ FILE: clients/phplib/Conjecture/Text.php ================================================ input = $text; } function toString() { return $this->input; } function replaceNumbers($replacement = "_num_") { $text = preg_replace("/[0-9]+/", $replacement, $this->input); return new Conjecture_Text(preg_replace("/".$replacement."\\s+".$replacement."/", $replacement, $text)); } function replaceHTMLEscapes($replacement = " ") { return new Conjecture_Text(preg_replace("/&[^;]+;/", $replacement, $this->input)); } function removeHTMLTags() { return new 
Conjecture_Text(preg_replace("/<.*?>/", " ", $this->input)); } function replaceHTMLTags($replacement = " ") { return new Conjecture_Text(preg_replace("/<[^>]+>/", " ", $this->input)); } function replaceNonAlphaNumeric($replacement = " ") { return new Conjecture_Text(preg_replace("/[^a-zA-Z0-9\\.\\s\\-]+/", $replacement, $this->input)); } function replaceNonAlphaNumericUnderscore($replacement = " ") { return new Conjecture_Text(preg_replace("/[^a-zA-Z0-9\\.\\s\\-_]+/", $replacement, $this->input)); } function replaceNonAlpha($replacement = " ") { return new Conjecture_Text(preg_replace("/[^a-zA-Z]+/", $replacement, $this->input)); } function collapseHyphens() { return new Conjecture_Text(preg_replace("/--+/", "--", $this->input)); } function collapseUnderscores() { return new Conjecture_Text(preg_replace("/__+/", "__", $this->input)); } function collapsePeriods() { return new Conjecture_Text(preg_replace("/\.\.+/", "..", $this->input)); } function stripPunctuation() { $temp = preg_replace("^[^A-Za-z0-9]", "", $this->input); return new Conjecture_Text(preg_replace("[^A-Za-z0-9]$", "", $temp)); } // compact any white space function collapse() { return new Conjecture_Text(preg_replace("/\\s+/", " ", $this->input)); } // remove any whitespace from the right of a string function rstrip() { return new Conjecture_Text(preg_replace("/\\s+$/", "", $this->input)); } // remove any whitespace from the left of a string function lstrip() { return new Conjecture_Text(preg_replace("/^\\s+/", "", $this->input)); } // remove any leading or trailing whitespace function strip() { return $this->rstrip()->lstrip(); } // clean up any whitespace function wsclean() { return $this->strip()->collapse(); } // remove any unprintable non-ASCII characters function removeUnprintables() { return new Conjecture_Text(preg_replace("/[^\\x20-\\x7E]/", "", $this->input)); } function collapseWhitespaceAndPunc() { $text = $this->collapse()->collapseHyphens(); return new 
Conjecture_Text(preg_replace("/\\.\\.+/", ".", $text->toString())); } function toLowerCase() { return new Conjecture_Text(strtolower($this->input)); } function standardTextFilter() { return $this->removeHTMLTags() ->replaceHTMLEscapes() ->replaceNumbers() ->replaceNonAlphaNumericUnderscore() ->collapseHyphens() ->collapseUnderscores() ->wsclean(); } function toArrayFromShingles($n) { $shingles = array(); $chars = str_split($this->input); for ($i = 0; $i < count($chars) - $n + 1; $i++) { $shingle = array_slice($chars, $i, $n); $shingles[] = implode("", $shingle); } return $shingles; } function toSequenceFromShingles($n) { return new Conjecture_TextSequence($this->toArrayFromShingles($n)); } } ================================================ FILE: clients/phplib/Conjecture/TextSequence.php ================================================ tokens = $tokens; } /** * concatenates two TextSequences into an additional text sequence */ function concat($other) { return new Conjecture_TextSequence(array_merge($this->tokens, $other->tokens)); } function mkString($glue = " ") { return implode($glue, $this->tokens); } function toString() { return $this->mkString(" "); } function getTokens() { return $this->tokens; } function filterBlank() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return $x !== ""; } ) ); } function filterStopwords() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return !in_array($x, self::$stopwordList); } ) ); } function stopwords() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return in_array($x, self::$stopwordList); } ) ); } function filterBadwords() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return !in_array($x, self::$badwordList); } ) ); } function badwords() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return in_array($x, self::$badwordList); } ) ); } function filterAllCaps() { 
return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return !preg_match('/^[A-Z]+$/', $x); } ) ); } function AllCaps() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return preg_match('/^[A-Z]+$/', $x); } ) ); } function filterCapitalized() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return !preg_match('/^[A-Z][^A-Z]+$/', $x); } ) ); } function capitalized() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return preg_match('/^[A-Z][^A-Z]+$/', $x); } ) ); } function filterLowercase() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return !preg_match('/^[a-z]+$/', $x); } ) ); } function allLowercase() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return preg_match('/^[a-z]+$/', $x); } ) ); } function filterURLs() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return !preg_match('/^https?://.+/', $x); } ) ); } function allURLs() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return preg_match('/^https?://.+/', $x); } ) ); } function filterListings() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return !preg_match('/^https?://.+etsy.+/listing/[0-9]+.*/', $x); } ) ); } function allListings() { return new Conjecture_TextSequence(array_filter($this->tokens, function($x) { return preg_match('/^https?://.+etsy.+/listing/[0-9]+.*/', $x); } ) ); } function size() { return count($this->tokens); } function stopWordCount() { return $this->stopwords()->size(); } function stopWordFraq($bins = 10.0) { return floor(round($bins*$this->stopWordCount()/$this->size())/$bins); } function badWordCount() { return $this->badwords()->size(); } function badWordFraq($bins = 10.0) { return floor(round($bins*$this->badWordCount()/$this->size())/$bins); } function capsCount() { return $this->allCaps()->size(); 
} function capFraq($bins = 10.0) { return floor(round($bins*$this->capsCount()/$this->size())/$bins); } function urlCount() { return $this->allURLs()->size(); } function urlFraq($bins = 10.0) { return floor(round($bins*$this->urlCount()/$this->size())/$bins); } function listingsCount() { return $this->badwords()->size(); } function listingsFraq($bins = 10.0) { return floor(round($bins*$this->allListings()/$this->size())/$bins); } function sizeBin() { return floor(log($this->size())); } // filtering methods (TODO) function replaceNumbers($replacement = "_num_") { return new Conjecture_TextSequence(array_map( function($x) use ($replacement) { $text = preg_replace("/[0-9]+/", $replacement, $x); return preg_replace("/".$replacement."\\s+".$replacement."/", $replacement, $text); }, $this->tokens)); } function replaceHTMLEscapes($replacement = " ") { return new Conjecture_TextSequence(array_map( function($x) use ($replacement) { return preg_replace("/&[^;]+;/", $replacement, $x); }, $this->tokens)); } function removeHTMLTags() { return $this->replaceHTMLTags(" "); } function replaceHTMLTags($replacement = " ") { return new Conjecture_TextSequence(array_map( function($x) use ($replacement) { return preg_replace("/<[^>]+>/", $replacement, $x); }, $this->tokens)); } function replaceNonAlphaNumeric($replacement = " ") { return new Conjecture_TextSequence(array_map( function($x) use ($replacement) { return preg_replace("/[^a-zA-Z0-9\\.\\s\\-]+/", $replacement, $x); }, $this->tokens)); } function replaceNonAlphaNumericUnderscore($replacement = " ") { return new Conjecture_TextSequence(array_map( function($x) use ($replacement) { return preg_replace("/[^a-zA-Z0-9\\.\\s\\-_]+/", $replacement, $x); }, $this->tokens)); } function replaceNonAlpha($replacement = " ") { return new Conjecture_TextSequence(array_map( function($x) use ($replacement) { return preg_replace("/[^a-zA-Z\\.\\s\\-_]+/", $replacement, $x); }, $this->tokens)); } function collapseHyphens() { return new 
Conjecture_TextSequence(array_map( function($x) { return preg_replace("/--+/", "--", $x); }, $this->tokens)); } function collapseUnderscores() { return new Conjecture_TextSequence(array_map( function($x) { return preg_replace("/__+/", "__", $x); }, $this->tokens)); } function collapsePeriods() { return new Conjecture_TextSequence(array_map( function($x) { return preg_replace("/\.\.+/", "..", $x); }, $this->tokens)); } function stripPunctuation() { return new Conjecture_TextSequence(array_map( function($x) { $temp = preg_replace("^[^A-Za-z0-9]", "", $x); return preg_replace("[^A-Za-z0-9]$", "", $temp); }, $this->tokens)); } // compact any white space function collapse() { return new Conjecture_TextSequence(array_map( function($x) { return preg_replace("/\\s+/", " ", $x); }, $this->tokens)); } function rstrip() { return new Conjecture_TextSequence(array_map( function($x) { return preg_replace("/^\\s+/", "", $x); }, $this->tokens)); } function lstrip() { return new Conjecture_TextSequence(array_map( function($x) { return preg_replace("/\\s+$/", "", $x); }, $this->tokens)); } // remove any leading or trailing whitespace function strip() { return $this->rstrip()->lstrip(); } // clean up any whitespace function wsclean() { return $this->strip()->collapse(); } // remove any unprintable non-ASCII characters function removeUnprintables() { return new Conjecture_TextSequence(array_map( function($x) { return preg_replace("/[^\\x20-\\x7E]/", "", $x); }, $this->tokens)); } function collapseWhitespaceAndPunc() { return new Conjecture_TextSequence(array_map( function($x) { $ws = preg_replace("/\\s+/", " ", $x); $dh = preg_replace("/[\\-]+/", "-", $ws); return preg_replace("/[\\.]+/", ".", $dh); }, $this->tokens)); } function prependNameSpace($namespace) { return new Conjecture_TextSequence(array_map( function($x) use ($namespace) { return $namespace . 
$x; }, $this->tokens)); } function toList() { return $this->tokens; } function shingles($n, $whitespace = "_") { $str = implode($whitespace, $this->tokens); $arr = explode('', $str); $shingles = array(); for ($i = 0; $i < count($arr) - $n; $i++) { $shingles[] = implode('', array_slice($arr, $i, $i + $n)); } return new Conjecture_TextSequence($shingles); } function ngrams($n, $glue = " ") { $grams = array(); for ($i = 0; $i < count($this->tokens) - $n+1; $i++) { $grams[] = implode($glue, array_slice($this->tokens, $i, $n)); } return new Conjecture_TextSequence($grams); } function unigramsAndBigrams($glue = " ") { return $this->ngrams(1)->concat($this->ngrams(2, $glue)); } function toInstance() { $instance = new Conjecture_Instance(); foreach ($this->tokens as $token) { $instance->addTerm($token); } return $instance; } static $stopwordList = array("a","as","able","about","above","according","accordingly","across","actually","after","afterwards","again","against","aint","all","allow","allows","almost","alone","along","already","also","although","always","am","among","amongst","amoungst","amount","an","and","another","any","anybody","anyhow","anyone","anything","anyway","anyways","anywhere","apart","appear","appreciate","appropriate","are","arent","around","as","aside","ask","asking","associated","at","available","away","awfully","b","back","be","became","because","become","becomes","becoming","been","before","beforehand","behind","being","believe","below","beside","besides","best","better","between","beyond","bill","both","bottom","brief","but","by","c","cmon","cs","call","came","can","cant","cannot","cant","cause","causes","certain","certainly","changes","clearly","co","com","come","comes","con","concerning","consequently","consider","considering","contain","containing","contains","corresponding","could","couldnt","couldnt","course","cry","currently","d","de","definitely","describe","described","despite","detail","did","didnt","different","do","does","doesnt","doing",
"dont","done","down","downwards","due","during","e","each","edu","eg","eight","either","eleven","else","elsewhere","empty","enough","entirely","especially","et","etc","even","ever","every","everybody","everyone","everything","everywhere","ex","exactly","example","except","f","far","few","fifteen","fifth","fify","fill","find","fire","first","five","followed","following","follows","for","former","formerly","forth","forty","found","four","from","front","full","further","furthermore","g","get","gets","getting","give","given","gives","go","goes","going","gone","got","gotten","greetings","h","had","hadnt","happens","hardly","has","hasnt","hasnt","have","havent","having","he","hes","hello","help","hence","her","here","heres","hereafter","hereby","herein","hereupon","hers","herself","hi","him","himself","his","hither","hopefully","how","howbeit","however","hundred","i","id","ill","im","ive","ie","if","ignored","immediate","in","inasmuch","inc","indeed","indicate","indicated","indicates","inner","insofar","instead","interest","into","inward","is","isnt","it","itd","itll","its","its","itself","j","just","k","keep","keeps","kept","know","known","knows","l","last","lately","later","latter","latterly","least","less","lest","let","lets","like","liked","likely","little","look","looking","looks","ltd","m","made","mainly","many","may","maybe","me","mean","meanwhile","merely","might","mill","mine","more","moreover","most","mostly","move","much","must","my","myself","n","name","namely","nd","near","nearly","necessary","need","needs","neither","never","nevertheless","new","next","nine","no","nobody","non","none","noone","nor","normally","not","nothing","novel","now","nowhere","o","obviously","of","off","often","oh","ok","okay","old","on","once","one","ones","only","onto","or","other","others","otherwise","ought","our","ours","ourselves","out","outside","over","overall","own","p","part","particular","particularly","per","perhaps","placed","please","plus","possible","presumably","probabl
y","provides","put","q","que","quite","qv","r","rather","rd","re","really","reasonably","regarding","regardless","regards","relatively","respectively","right","s","said","same","saw","say","saying","says","second","secondly","see","seeing","seem","seemed","seeming","seems","seen","self","selves","sensible","sent","serious","seriously","seven","several","shall","she","should","shouldnt","show","side","since","sincere","six","sixty","so","some","somebody","somehow","someone","something","sometime","sometimes","somewhat","somewhere","soon","sorry","specified","specify","specifying","still","sub","such","sup","sure","system","t","ts","take","taken","tell","ten","tends","th","than","thank","thanks","thanx","that","thats","thats","the","thea","their","theirs","them","themselves","then","thence","there","theres","thereafter","thereby","therefore","therein","theres","thereupon","these","they","theyd","theyll","theyre","theyve","thickv","thin","think","third","this","thorough","thoroughly","those","though","three","through","throughout","thru","thus","to","together","too","took","top","toward","towards","tried","tries","truly","try","trying","twelve","twenty","twice","two","u","un","under","unfortunately","unless","unlikely","until","unto","up","re","werent","what","whats","whatever","when","whence","whenever","where","wheres","whereafter","whereas","whereby","wherein","whereupon","wherever","whether","which","while","whither","who","whos","whoever","whole","whom","whose","why","will","willing","wish","with","within","without","wont","wonder","would","wouldnt","x","y","yes","yet","you","youd","youll","youre","youve","your","yours","yourself","yourselves","z","zero"); static $badwordList = array("ahole", "arse", "ass", "asshole", "asswipe", "bastard", "batty", "bender", "bitch", "bloody", "bollocks", "boner", "bumboy", "bugger", "coon", "cock", "cocksucker", "cracker", "crap", "cumsucker", "cunt", "damn", "dick", "dildo", "douchebag", "faggot", "fistfucker", "fuck", 
"fucker", "fuckwit", "fucktwat", "gaylord", "ho", "honky", "jackass", "jism", "joey", "knobcheese", "minge", "minger", "mong", "motherfucker", "munter", "pickle", "piss", "piss", "prick", "pussy", "rimmer", "schmuck", "shit", "slut", "spakka", "spaz", "skank", "taint", "tit", "tool", "tosser", "twat", "whore", "wanker"); } ================================================ FILE: clients/phplib/Conjecture/Vector.php ================================================ vector = (array)$array; } public function dot($rhs) { $keys = array_intersect_key($this->vector, $rhs->vector); $res = 0.0; foreach ($keys as $key => $val) { $res += $this->vector[$key] * $rhs->vector[$key]; } return $res; } public function getParams() { return $this->vector; } public function getParam($k) { if (array_key_exists($k, $this->vector)) { return $this->vector[$k]; } else { return 0.0; } } } ================================================ FILE: data/iris.tsv ================================================ 7.0 3.2 4.7 1.4 Iris-versicolor 5.6 3.0 4.1 1.3 Iris-versicolor 5.4 3.4 1.7 0.2 Iris-setosa 5.0 3.0 1.6 0.2 Iris-setosa 6.9 3.2 5.7 2.3 Iris-virginica 4.9 3.0 1.4 0.2 Iris-setosa 5.0 2.3 3.3 1.0 Iris-versicolor 5.2 2.7 3.9 1.4 Iris-versicolor 5.1 3.8 1.9 0.4 Iris-setosa 7.2 3.6 6.1 2.5 Iris-virginica 4.8 3.4 1.6 0.2 Iris-setosa 6.0 2.9 4.5 1.5 Iris-versicolor 5.8 2.6 4.0 1.2 Iris-versicolor 5.7 2.6 3.5 1.0 Iris-versicolor 5.9 3.0 4.2 1.5 Iris-versicolor 5.5 2.3 4.0 1.3 Iris-versicolor 4.6 3.2 1.4 0.2 Iris-setosa 6.3 2.8 5.1 1.5 Iris-virginica 6.3 3.3 6.0 2.5 Iris-virginica 6.9 3.1 4.9 1.5 Iris-versicolor 6.7 3.3 5.7 2.5 Iris-virginica 5.1 3.7 1.5 0.4 Iris-setosa 6.7 3.3 5.7 2.1 Iris-virginica 5.8 2.8 5.1 2.4 Iris-virginica 6.0 3.4 4.5 1.6 Iris-versicolor 5.4 3.0 4.5 1.5 Iris-versicolor 5.5 3.5 1.3 0.2 Iris-setosa 5.0 3.3 1.4 0.2 Iris-setosa 5.7 4.4 1.5 0.4 Iris-setosa 5.3 3.7 1.5 0.2 Iris-setosa 5.2 3.5 1.5 0.2 Iris-setosa 6.5 2.8 4.6 1.5 Iris-versicolor 7.4 2.8 6.1 1.9 Iris-virginica 4.9 3.1 
1.5 0.2 Iris-setosa 5.0 3.2 1.2 0.2 Iris-setosa 7.7 2.8 6.7 2.0 Iris-virginica 4.8 3.4 1.9 0.2 Iris-setosa 6.5 3.0 5.2 2.0 Iris-virginica 6.3 2.5 5.0 1.9 Iris-virginica 6.4 3.1 5.5 1.8 Iris-virginica 5.8 2.7 5.1 1.9 Iris-virginica 7.1 3.0 5.9 2.1 Iris-virginica 5.7 2.5 5.0 2.0 Iris-virginica 6.4 2.8 5.6 2.2 Iris-virginica 6.4 3.2 4.5 1.5 Iris-versicolor 6.1 2.6 5.6 1.4 Iris-virginica 4.8 3.0 1.4 0.1 Iris-setosa 5.6 2.8 4.9 2.0 Iris-virginica 6.0 2.2 5.0 1.5 Iris-virginica 5.0 3.5 1.3 0.3 Iris-setosa 5.5 2.6 4.4 1.2 Iris-versicolor 5.0 3.6 1.4 0.2 Iris-setosa 5.0 3.4 1.6 0.4 Iris-setosa 6.3 2.7 4.9 1.8 Iris-virginica 6.7 3.1 4.7 1.5 Iris-versicolor 6.3 2.5 4.9 1.5 Iris-versicolor 4.5 2.3 1.3 0.3 Iris-setosa 6.8 3.2 5.9 2.3 Iris-virginica 7.2 3.2 6.0 1.8 Iris-virginica 5.5 2.4 3.8 1.1 Iris-versicolor 5.8 2.7 5.1 1.9 Iris-virginica 6.1 2.8 4.0 1.3 Iris-versicolor 6.3 2.9 5.6 1.8 Iris-virginica 6.1 2.9 4.7 1.4 Iris-versicolor 6.3 2.3 4.4 1.3 Iris-versicolor 4.6 3.4 1.4 0.3 Iris-setosa 5.5 4.2 1.4 0.2 Iris-setosa 6.5 3.0 5.5 1.8 Iris-virginica 6.7 3.1 4.4 1.4 Iris-versicolor 6.6 2.9 4.6 1.3 Iris-versicolor 5.9 3.0 5.1 1.8 Iris-virginica 6.4 2.7 5.3 1.9 Iris-virginica 5.6 2.5 3.9 1.1 Iris-versicolor 6.4 3.2 5.3 2.3 Iris-virginica 5.7 3.8 1.7 0.3 Iris-setosa 7.2 3.0 5.8 1.6 Iris-virginica 6.7 3.0 5.2 2.3 Iris-virginica 4.6 3.1 1.5 0.2 Iris-setosa 5.6 2.9 3.6 1.3 Iris-versicolor 6.4 2.9 4.3 1.3 Iris-versicolor 5.1 3.5 1.4 0.2 Iris-setosa 7.6 3.0 6.6 2.1 Iris-virginica 5.7 2.8 4.1 1.3 Iris-versicolor 5.6 2.7 4.2 1.3 Iris-versicolor 5.7 2.9 4.2 1.3 Iris-versicolor 5.4 3.7 1.5 0.2 Iris-setosa 6.4 2.8 5.6 2.1 Iris-virginica 4.6 3.6 1.0 0.2 Iris-setosa 4.4 2.9 1.4 0.2 Iris-setosa 4.4 3.2 1.3 0.2 Iris-setosa 6.2 3.4 5.4 2.3 Iris-virginica 6.3 3.4 5.6 2.4 Iris-virginica 6.8 2.8 4.8 1.4 Iris-versicolor 5.1 3.4 1.5 0.2 Iris-setosa 6.1 3.0 4.9 1.8 Iris-virginica 5.7 3.0 4.2 1.2 Iris-versicolor 5.0 3.4 1.5 0.2 Iris-setosa 5.0 3.5 1.6 0.6 Iris-setosa 7.7 3.8 6.7 2.2 Iris-virginica 4.9 
3.1 1.5 0.1 Iris-setosa 6.0 2.2 4.0 1.0 Iris-versicolor 6.8 3.0 5.5 2.1 Iris-virginica 5.1 2.5 3.0 1.1 Iris-versicolor 6.5 3.2 5.1 2.0 Iris-virginica 4.7 3.2 1.3 0.2 Iris-setosa 6.6 3.0 4.4 1.4 Iris-versicolor 6.7 3.0 5.0 1.7 Iris-versicolor 4.8 3.0 1.4 0.3 Iris-setosa 5.1 3.8 1.5 0.3 Iris-setosa 7.7 2.6 6.9 2.3 Iris-virginica 5.1 3.8 1.6 0.2 Iris-setosa 5.0 2.0 3.5 1.0 Iris-versicolor 7.7 3.0 6.1 2.3 Iris-virginica 6.5 3.0 5.8 2.2 Iris-virginica 5.8 4.0 1.2 0.2 Iris-setosa 5.4 3.4 1.5 0.4 Iris-setosa 6.2 2.2 4.5 1.5 Iris-versicolor 5.7 2.8 4.5 1.3 Iris-versicolor 5.5 2.5 4.0 1.3 Iris-versicolor 7.3 2.9 6.3 1.8 Iris-virginica 5.6 3.0 4.5 1.5 Iris-versicolor 6.2 2.8 4.8 1.8 Iris-virginica 4.3 3.0 1.1 0.1 Iris-setosa 5.8 2.7 3.9 1.2 Iris-versicolor 7.9 3.8 6.4 2.0 Iris-virginica 6.2 2.9 4.3 1.3 Iris-versicolor 4.9 2.5 4.5 1.7 Iris-virginica 4.9 3.6 1.4 0.1 Iris-setosa 5.2 3.4 1.4 0.2 Iris-setosa 6.0 2.7 5.1 1.6 Iris-versicolor 6.9 3.1 5.4 2.1 Iris-virginica 4.8 3.1 1.6 0.2 Iris-setosa 6.7 3.1 5.6 2.4 Iris-virginica 6.3 3.3 4.7 1.6 Iris-versicolor 5.2 4.1 1.5 0.1 Iris-setosa 5.4 3.9 1.3 0.4 Iris-setosa 4.9 2.4 3.3 1.0 Iris-versicolor 5.5 2.4 3.7 1.0 Iris-versicolor 5.1 3.5 1.4 0.3 Iris-setosa 6.1 3.0 4.6 1.4 Iris-versicolor 5.1 3.3 1.7 0.5 Iris-setosa 4.4 3.0 1.3 0.2 Iris-setosa 5.9 3.2 4.8 1.8 Iris-versicolor 4.7 3.2 1.6 0.2 Iris-setosa 6.9 3.1 5.1 2.3 Iris-virginica 5.4 3.9 1.7 0.4 Iris-setosa 5.8 2.7 4.1 1.0 Iris-versicolor 6.1 2.8 4.7 1.2 Iris-versicolor 6.0 3.0 4.8 1.8 Iris-virginica 6.7 2.5 5.8 1.8 Iris-virginica ================================================ FILE: project/build.properties ================================================ sbt.version=0.13.9 ================================================ FILE: project/plugins.sbt ================================================ addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.13.0") addSbtPlugin("no.arktekk.sbt" % "aether-deploy" % "0.14") addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "0.2.1") 
addSbtPlugin("com.typesafe.sbt" % "sbt-pgp" % "0.8.3") ================================================ FILE: sbt ================================================ #!/usr/bin/env bash # # A more capable sbt runner, coincidentally also called sbt. # Author: Paul Phillips # todo - make this dynamic declare -r sbt_release_version="0.13.8" declare -r sbt_unreleased_version="0.13.9-M1" declare -r buildProps="project/build.properties" declare sbt_jar sbt_dir sbt_create sbt_version declare scala_version sbt_explicit_version declare verbose noshare batch trace_level log_level declare sbt_saved_stty debugUs echoerr () { echo >&2 "$@"; } vlog () { [[ -n "$verbose" ]] && echoerr "$@"; } # spaces are possible, e.g. sbt.version = 0.13.0 build_props_sbt () { [[ -r "$buildProps" ]] && \ grep '^sbt\.version' "$buildProps" | tr '=\r' ' ' | awk '{ print $2; }' } update_build_props_sbt () { local ver="$1" local old="$(build_props_sbt)" [[ -r "$buildProps" ]] && [[ "$ver" != "$old" ]] && { perl -pi -e "s/^sbt\.version\b.*\$/sbt.version=${ver}/" "$buildProps" grep -q '^sbt.version[ =]' "$buildProps" || printf "\nsbt.version=%s\n" "$ver" >> "$buildProps" vlog "!!!" vlog "!!! Updated file $buildProps setting sbt.version to: $ver" vlog "!!! Previous value was: $old" vlog "!!!" } } set_sbt_version () { sbt_version="${sbt_explicit_version:-$(build_props_sbt)}" [[ -n "$sbt_version" ]] || sbt_version=$sbt_release_version export sbt_version } # restore stty settings (echo in particular) onSbtRunnerExit() { [[ -n "$sbt_saved_stty" ]] || return vlog "" vlog "restoring stty: $sbt_saved_stty" stty "$sbt_saved_stty" unset sbt_saved_stty } # save stty and trap exit, to ensure echo is reenabled if we are interrupted. trap onSbtRunnerExit EXIT sbt_saved_stty="$(stty -g 2>/dev/null)" vlog "Saved stty: $sbt_saved_stty" # this seems to cover the bases on OSX, and someone will # have to tell me about the others. 
get_script_path () { local path="$1" [[ -L "$path" ]] || { echo "$path" ; return; } local target="$(readlink "$path")" if [[ "${target:0:1}" == "/" ]]; then echo "$target" else echo "${path%/*}/$target" fi } die() { echo "Aborting: $@" exit 1 } make_url () { version="$1" case "$version" in 0.7.*) echo "http://simple-build-tool.googlecode.com/files/sbt-launch-0.7.7.jar" ;; 0.10.* ) echo "$sbt_launch_repo/org.scala-tools.sbt/sbt-launch/$version/sbt-launch.jar" ;; 0.11.[12]) echo "$sbt_launch_repo/org.scala-tools.sbt/sbt-launch/$version/sbt-launch.jar" ;; *) echo "$sbt_launch_repo/org.scala-sbt/sbt-launch/$version/sbt-launch.jar" ;; esac } init_default_option_file () { local overriding_var="${!1}" local default_file="$2" if [[ ! -r "$default_file" && "$overriding_var" =~ ^@(.*)$ ]]; then local envvar_file="${BASH_REMATCH[1]}" if [[ -r "$envvar_file" ]]; then default_file="$envvar_file" fi fi echo "$default_file" } declare -r cms_opts="-XX:+CMSClassUnloadingEnabled -XX:+UseConcMarkSweepGC" declare -r jit_opts="-XX:ReservedCodeCacheSize=256m -XX:+TieredCompilation" declare -r default_jvm_opts_common="-Xms512m -Xmx1536m -Xss2m $jit_opts $cms_opts" declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy" declare -r latest_28="2.8.2" declare -r latest_29="2.9.3" declare -r latest_210="2.10.5" declare -r latest_211="2.11.7" declare -r script_path="$(get_script_path "$BASH_SOURCE")" declare -r script_name="${script_path##*/}" # some non-read-onlies set with defaults declare java_cmd="java" declare sbt_opts_file="$(init_default_option_file SBT_OPTS .sbtopts)" declare jvm_opts_file="$(init_default_option_file JVM_OPTS .jvmopts)" declare sbt_launch_repo="http://repo.typesafe.com/typesafe/ivy-releases" # pull -J and -D options to give to java. 
declare -a residual_args declare -a java_args declare -a scalac_args declare -a sbt_commands # args to jvm/sbt via files or environment variables declare -a extra_jvm_opts extra_sbt_opts addJava () { vlog "[addJava] arg = '$1'" java_args+=("$1") } addSbt () { vlog "[addSbt] arg = '$1'" sbt_commands+=("$1") } setThisBuild () { vlog "[addBuild] args = '$@'" local key="$1" && shift addSbt "set $key in ThisBuild := $@" } addScalac () { vlog "[addScalac] arg = '$1'" scalac_args+=("$1") } addResidual () { vlog "[residual] arg = '$1'" residual_args+=("$1") } addResolver () { addSbt "set resolvers += $1" } addDebugger () { addJava "-Xdebug" addJava "-Xrunjdwp:transport=dt_socket,server=y,suspend=n,address=$1" } setScalaVersion () { [[ "$1" == *"-SNAPSHOT" ]] && addResolver 'Resolver.sonatypeRepo("snapshots")' addSbt "++ $1" } setJavaHome () { java_cmd="$1/bin/java" setThisBuild javaHome "Some(file(\"$1\"))" export JAVA_HOME="$1" export JDK_HOME="$1" export PATH="$JAVA_HOME/bin:$PATH" } setJavaHomeQuietly () { addSbt warn setJavaHome "$1" addSbt info } # if set, use JDK_HOME/JAVA_HOME over java found in path if [[ -e "$JDK_HOME/lib/tools.jar" ]]; then setJavaHomeQuietly "$JDK_HOME" elif [[ -e "$JAVA_HOME/bin/java" ]]; then setJavaHomeQuietly "$JAVA_HOME" fi # directory to store sbt launchers declare sbt_launch_dir="$HOME/.sbt/launchers" [[ -d "$sbt_launch_dir" ]] || mkdir -p "$sbt_launch_dir" [[ -w "$sbt_launch_dir" ]] || sbt_launch_dir="$(mktemp -d -t sbt_extras_launchers.XXXXXX)" java_version () { local version=$("$java_cmd" -version 2>&1 | grep -E -e '(java|openjdk) version' | awk '{ print $3 }' | tr -d \") vlog "Detected Java version: $version" echo "${version:2:1}" } # MaxPermSize critical on pre-8 jvms but incurs noisy warning on 8+ default_jvm_opts () { local v="$(java_version)" if [[ $v -ge 8 ]]; then echo "$default_jvm_opts_common" else echo "-XX:MaxPermSize=384m $default_jvm_opts_common" fi } build_props_scala () { if [[ -r "$buildProps" ]]; then 
versionLine="$(grep '^build.scala.versions' "$buildProps")" versionString="${versionLine##build.scala.versions=}" echo "${versionString%% .*}" fi } execRunner () { # print the arguments one to a line, quoting any containing spaces vlog "# Executing command line:" && { for arg; do if [[ -n "$arg" ]]; then if printf "%s\n" "$arg" | grep -q ' '; then printf >&2 "\"%s\"\n" "$arg" else printf >&2 "%s\n" "$arg" fi fi done vlog "" } [[ -n "$batch" ]] && exec /dev/null; then curl --fail --silent --location "$url" --output "$jar" elif which wget >/dev/null; then wget --quiet -O "$jar" "$url" fi } && [[ -r "$jar" ]] } acquire_sbt_jar () { sbt_url="$(jar_url "$sbt_version")" sbt_jar="$(jar_file "$sbt_version")" [[ -r "$sbt_jar" ]] || download_url "$sbt_url" "$sbt_jar" } usage () { cat < display stack traces with a max of frames (default: -1, traces suppressed) -debug-inc enable debugging log for the incremental compiler -no-colors disable ANSI color codes -sbt-create start sbt even if current directory contains no sbt project -sbt-dir path to global settings/plugins directory (default: ~/.sbt/) -sbt-boot path to shared boot directory (default: ~/.sbt/boot in 0.11+) -ivy path to local Ivy repository (default: ~/.ivy2) -no-share use all local caches; no sharing -offline put sbt in offline mode -jvm-debug Turn on JVM debugging, open at the given port. 
-batch Disable interactive mode -prompt Set the sbt prompt; in expr, 's' is the State and 'e' is Extracted # sbt version (default: sbt.version from $buildProps if present, otherwise $sbt_release_version) -sbt-force-latest force the use of the latest release of sbt: $sbt_release_version -sbt-version use the specified version of sbt (default: $sbt_release_version) -sbt-dev use the latest pre-release version of sbt: $sbt_unreleased_version -sbt-jar use the specified jar as the sbt launcher -sbt-launch-dir directory to hold sbt launchers (default: ~/.sbt/launchers) -sbt-launch-repo repo url for downloading sbt launcher jar (default: $sbt_launch_repo) # scala version (default: as chosen by sbt) -28 use $latest_28 -29 use $latest_29 -210 use $latest_210 -211 use $latest_211 -scala-home use the scala build at the specified directory -scala-version use the specified version of scala -binary-version use the specified scala version when searching for dependencies # java version (default: java from PATH, currently $(java -version 2>&1 | grep version)) -java-home alternate JAVA_HOME # passing options to the jvm - note it does NOT use JAVA_OPTS due to pollution # The default set is used if JVM_OPTS is unset and no -jvm-opts file is found $(default_jvm_opts) JVM_OPTS environment variable holding either the jvm args directly, or the reference to a file containing jvm args if given path is prepended by '@' (e.g. '@/etc/jvmopts') Note: "@"-file is overridden by local '.jvmopts' or '-jvm-opts' argument. -jvm-opts file containing jvm args (if not given, .jvmopts in project root is used if present) -Dkey=val pass -Dkey=val directly to the jvm -J-X pass option -X directly to the jvm (-J is stripped) # passing options to sbt, OR to this runner SBT_OPTS environment variable holding either the sbt args directly, or the reference to a file containing sbt args if given path is prepended by '@' (e.g. '@/etc/sbtopts') Note: "@"-file is overridden by local '.sbtopts' or '-sbt-opts' argument. 
-sbt-opts file containing sbt args (if not given, .sbtopts in project root is used if present) -S-X add -X to sbt's scalacOptions (-S is stripped) EOM } process_args () { require_arg () { local type="$1" local opt="$2" local arg="$3" if [[ -z "$arg" ]] || [[ "${arg:0:1}" == "-" ]]; then die "$opt requires <$type> argument" fi } while [[ $# -gt 0 ]]; do case "$1" in -h|-help) usage; exit 1 ;; -v) verbose=true && shift ;; -d) addSbt "--debug" && addSbt debug && shift ;; -w) addSbt "--warn" && addSbt warn && shift ;; -q) addSbt "--error" && addSbt error && shift ;; -x) debugUs=true && shift ;; -trace) require_arg integer "$1" "$2" && trace_level="$2" && shift 2 ;; -ivy) require_arg path "$1" "$2" && addJava "-Dsbt.ivy.home=$2" && shift 2 ;; -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;; -no-share) noshare=true && shift ;; -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;; -sbt-dir) require_arg path "$1" "$2" && sbt_dir="$2" && shift 2 ;; -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;; -offline) addSbt "set offline := true" && shift ;; -jvm-debug) require_arg port "$1" "$2" && addDebugger "$2" && shift 2 ;; -batch) batch=true && shift ;; -prompt) require_arg "expr" "$1" "$2" && setThisBuild shellPrompt "(s => { val e = Project.extract(s) ; $2 })" && shift 2 ;; -sbt-create) sbt_create=true && shift ;; -sbt-jar) require_arg path "$1" "$2" && sbt_jar="$2" && shift 2 ;; -sbt-version) require_arg version "$1" "$2" && sbt_explicit_version="$2" && shift 2 ;; -sbt-force-latest) sbt_explicit_version="$sbt_release_version" && shift ;; -sbt-dev) sbt_explicit_version="$sbt_unreleased_version" && shift ;; -sbt-launch-dir) require_arg path "$1" "$2" && sbt_launch_dir="$2" && shift 2 ;; -sbt-launch-repo) require_arg path "$1" "$2" && sbt_launch_repo="$2" && shift 2 ;; -scala-version) require_arg version "$1" "$2" && setScalaVersion "$2" && shift 2 ;; -binary-version) require_arg version "$1" "$2" && setThisBuild 
scalaBinaryVersion "\"$2\"" && shift 2 ;; -scala-home) require_arg path "$1" "$2" && setThisBuild scalaHome "Some(file(\"$2\"))" && shift 2 ;; -java-home) require_arg path "$1" "$2" && setJavaHome "$2" && shift 2 ;; -sbt-opts) require_arg path "$1" "$2" && sbt_opts_file="$2" && shift 2 ;; -jvm-opts) require_arg path "$1" "$2" && jvm_opts_file="$2" && shift 2 ;; -D*) addJava "$1" && shift ;; -J*) addJava "${1:2}" && shift ;; -S*) addScalac "${1:2}" && shift ;; -28) setScalaVersion "$latest_28" && shift ;; -29) setScalaVersion "$latest_29" && shift ;; -210) setScalaVersion "$latest_210" && shift ;; -211) setScalaVersion "$latest_211" && shift ;; --debug) addSbt debug && addResidual "$1" && shift ;; --warn) addSbt warn && addResidual "$1" && shift ;; --error) addSbt error && addResidual "$1" && shift ;; *) addResidual "$1" && shift ;; esac done } # process the direct command line arguments process_args "$@" # skip #-styled comments and blank lines readConfigFile() { while read line; do [[ $line =~ ^# ]] || [[ -z $line ]] || echo "$line" done < "$1" } # if there are file/environment sbt_opts, process again so we # can supply args to this runner if [[ -r "$sbt_opts_file" ]]; then vlog "Using sbt options defined in file $sbt_opts_file" while read opt; do extra_sbt_opts+=("$opt"); done < <(readConfigFile "$sbt_opts_file") elif [[ -n "$SBT_OPTS" && ! 
("$SBT_OPTS" =~ ^@.*) ]]; then vlog "Using sbt options defined in variable \$SBT_OPTS" extra_sbt_opts=( $SBT_OPTS ) else vlog "No extra sbt options have been defined" fi [[ -n "${extra_sbt_opts[*]}" ]] && process_args "${extra_sbt_opts[@]}" # reset "$@" to the residual args set -- "${residual_args[@]}" argumentCount=$# # set sbt version set_sbt_version # only exists in 0.12+ setTraceLevel() { case "$sbt_version" in "0.7."* | "0.10."* | "0.11."* ) echoerr "Cannot set trace level in sbt version $sbt_version" ;; *) setThisBuild traceLevel $trace_level ;; esac } # set scalacOptions if we were given any -S opts [[ ${#scalac_args[@]} -eq 0 ]] || addSbt "set scalacOptions in ThisBuild += \"${scalac_args[@]}\"" # Update build.properties on disk to set explicit version - sbt gives us no choice [[ -n "$sbt_explicit_version" ]] && update_build_props_sbt "$sbt_explicit_version" vlog "Detected sbt version $sbt_version" [[ -n "$scala_version" ]] && vlog "Overriding scala version to $scala_version" # no args - alert them there's stuff in here (( argumentCount > 0 )) || { vlog "Starting $script_name: invoke with -help for other options" residual_args=( shell ) } # verify this is an sbt dir or -create was given [[ -r ./build.sbt || -d ./project || -n "$sbt_create" ]] || { cat < implements java.io.Serializable { private static final long serialVersionUID = 123L; public F first; public S second; /** * Class constructor specifying the first and second number to create * * @param first * first number * @param second * second number */ public GenericPair(F first, S second) { this.first = first; this.second = second; } /** * The method gets first number * * @return first number */ public F getFirst() { return first; } /** * The method sets first number * * @param fisrt * first number */ public void setFirst(F first) { this.first = first; } /** * The method gets second number * * @return second number */ public S getSecond() { return second; } /** * The method sets second number * * 
@param second * second number */ public void setSecond(S second) { this.second = second; } @Override public String toString() { return first + "," + second; } @SuppressWarnings("unchecked") public boolean equals(Object o) { if (!(o instanceof GenericPair)) return false; GenericPair p = (GenericPair)o; return (p.first).equals(first) && (p.second).equals(second); } public int hashCode() { return 17 + first.hashCode() * 31 + second.hashCode(); } } ================================================ FILE: src/main/java/com/etsy/conjecture/PrimitivePair.java ================================================ package com.etsy.conjecture; /** * PrimitivePair is JavaBean * * @author Josh Attenberg */ public class PrimitivePair implements java.io.Serializable { private static final long serialVersionUID = 1234L; public double first; public double second; /** * Class constructor specifying the first and second number to create * * @param first * first number * @param second * second number */ public PrimitivePair(double first, double second) { this.first = first; this.second = second; } /** * The method gets first number * * @return first number */ public double getFirst() { return first; } /** * The method sets first number * * @param fisrt * first number */ public void setFirst(double fisrt) { this.first = fisrt; } /** * The method gets second number * * @return second number */ public double getSecond() { return second; } /** * The method sets second number * * @param second * second number */ public void setSecond(double second) { this.second = second; } @Override public String toString() { return first + "," + second; } @Override public boolean equals(Object o) { if (!(o instanceof PrimitivePair)) return false; PrimitivePair p = (PrimitivePair)o; return p.first == first && p.second == second; } @Override public int hashCode() { return (17 + Utilities.doubleHash(first)) * 31 + Utilities.doubleHash(second); } } ================================================ FILE: 
src/main/java/com/etsy/conjecture/Utilities.java ================================================ package com.etsy.conjecture; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.StringTokenizer; import org.apache.commons.lang.StringUtils; import com.google.common.hash.*; import com.google.common.collect.Lists; /** * class of static data science utility methods * * @author jattenberg * */ public class Utilities { /** Tolerance used by {@link #floatingPointEquals(double, double)}. */ public static final double SMALL = 1e-10; public static final HashFunction HASHER = Hashing.md5(); public static final double ROOT2 = Math.sqrt(2d); public static final double LOG2 = Math.log(2.); /* static-only utility class; not instantiable */ private Utilities() { } /** * Keeps ASCII letters, replaces every other character with a space, and lower-cases the result. * Lower-casing uses Locale.ROOT so the output never depends on the JVM default locale (under a * Turkish default locale, 'I' would otherwise become the non-ASCII dotless i). * * @param line input text * @return a string of the same length containing only lower-case ASCII letters and spaces */ public static String cleanLine(String line) { StringBuilder buffer = new StringBuilder(line.length()); for (int i = 0; i < line.length(); i++) { char c = line.charAt(i); if (c < 128 && Character.isLetter(c)) { buffer.append(c); } else { buffer.append(' '); } } return buffer.toString().toLowerCase(Locale.ROOT); } /** * Splits input on whitespace and common punctuation, collapses runs of '-' in each token and * strips a leading and trailing '-', then drops tokens shorter than two characters and, when * ignoreNumbers is set, tokens containing a digit. The surviving tokens are joined with * separator. * * @param input text to clean * @param separator string placed between surviving tokens * @param ignoreNumbers whether to drop tokens that contain any digit 0-9 * @return the joined tokens, with no trailing separator */ public static String cleanLineRobust(String input, String separator, boolean ignoreNumbers) { StringBuilder buff = new StringBuilder(); StringTokenizer tokenizer = new StringTokenizer(input, " +.,~\\<>\\$?!:;(){}|" + "\b\t\n\f\r\"\'\\\\/\\=\\&\\%\\_"); while (tokenizer.hasMoreTokens()) { String token = tokenizer.nextToken(); token = token.replaceAll("-{2,}", "-"); token = token.replaceAll("^-", ""); token = token.replaceAll("-$", ""); if (token.length() < 2 || (ignoreNumbers && StringUtils.containsAny(token, "0123456789"))) continue; buff.append(token + separator); } /* every kept token is followed by separator, so its last occurrence is the trailing one */ int index = buff.lastIndexOf(separator); if (index >= 0) buff.delete(index, buff.length()); return buff.toString(); } /** * Returns s unchanged, or throws IllegalArgumentException when StringUtils.isBlank(s) is true * (null, empty or whitespace-only). */ public static String checkNotBlank(String s) { if (StringUtils.isBlank(s)) { throw new IllegalArgumentException("Argument cannot be blank"); } return s; } /** Applies {@link #checkNotBlank(String)} to every element and returns the list unchanged. */ public static List checkNotBlank(List S) { for (String s : S)
checkNotBlank(s); return S; } /** Applies {@link #checkNotBlank(String)} to every element and returns the array unchanged. */ public static String[] checkNotBlank(String[] S) { for (String s : S) checkNotBlank(s); return S; } /** * Sums the coefficient of every token in input (once per occurrence); tokens absent from * coefficients contribute 0. */ public static double stringInnerProduct(Map coefficients, Collection input) { double output = 0; for (String token : input) output += coefficients.containsKey(token) ? coefficients.get(token) : 0; return output; } /** Logistic sigmoid, 1 / (1 + e^-x). */ public static double sigmoid(double operand) { return 1. / (1. + Math.exp(-operand)); } /** * derivative of the sigmoid function, e^x / (1 + e^x)^2. * * The derivative is an even function, so it is evaluated as e^-|x| / (1 + e^-|x|)^2. The * exponent is never positive, which avoids the Infinity / Infinity = NaN the direct formula * produced for x above roughly 709, and keeps full relative precision far into both tails. */ public static double dsigmoid(double operand) { double e = Math.exp(-Math.abs(operand)); double denom = 1. + e; return e / (denom * denom); } /** * Returns the whitespace-separated terms of input in sorted (natural String) order, joined by a * single space. The split regex must not be reused as the join separator: StringUtils.join * inserts its separator literally, which would put the text \s+ between terms. * * @param input whitespace-separated terms * @return the sorted terms separated by single spaces */ public static String sortTerms(String input) { String[] terms = input.split("\\s+"); Arrays.sort(terms); return StringUtils.join(terms, " "); } /** * Splits input on the regular expression delim, sorts the pieces, and joins them back with delim * taken literally, so the result only uses the same separator when delim contains no regex * metacharacters. */ public static String sortTerms(String input, String delim) { String[] terms = input.split(delim); Arrays.sort(terms); return StringUtils.join(terms, delim); } /** * Splits tmp on whitespace, punctuation, '-' and digits, keeps tokens whose length is between 2 * and maxlen inclusive, and returns each kept token followed by one space (a non-empty result * therefore ends with a space). */ public final static String cleanText(String tmp, int maxlen) { StringTokenizer tok = new StringTokenizer(tmp, " +.,~\\<>\\$?!:;(){}|-0123456789\b\t\n\f\r\"\'\\\\/\\=\\&\\%\\_"); StringBuilder buff = new StringBuilder(); while (tok.hasMoreTokens()) { String out = tok.nextToken(); if (out.length() < 2 || out.length() > maxlen) continue; buff.append(out + " "); } return buff.toString(); } /** * Builds word n-grams from the whitespace-split input. For each token position and then each size * in gramSizes, emits the n-gram ending at that token when enough preceding tokens exist; the * tokens of an n-gram are joined with separator. NOTE(review): sizes of 0 or less currently emit * the single token, the same as size 1. */ public final static List grams(String input, int[] gramSizes, String separator) { List out = Lists.newArrayList(); StringBuilder buff = new StringBuilder(); String[] tokens = StringUtils.split(input); for (int i = 0; i < tokens.length; i++) { String token = tokens[i]; for (int len : gramSizes) { if (len > i + 1) continue; if (len == 1) { out.add(token); continue; } buff.setLength(0); for (int k = len - 1; k > 0; k--) buff.append(tokens[i - k] + separator); buff.append(token); out.add(buff.toString()); } } return out; } /** * True when a and b differ by less than {@link #SMALL}. Exact equality is checked first so that * equal infinities, whose difference is NaN, also compare equal; NaN is never equal to anything. */ public static final boolean floatingPointEquals(double a, double b) { return a == b || ((a - b < SMALL) && (b - a < SMALL)); } /** Hashes a double the same way {@link Double#hashCode()} does: xor of the high and low 32 bits of its doubleToLongBits pattern. */ public static int doubleHash(double d) { long t =
Double.doubleToLongBits(d); return (int)(t ^ (t >>> 32)); } /** Same function as {@link #sigmoid(double)}: 1 / (1 + e^-x). */ public static double logistic(double x) { return 1d / (1 + Math.exp(-x)); } /* Orders map entries by value, descending when reverse is set. Descending order swaps the operands instead of negating the result, because negating Integer.MIN_VALUE overflows. */ static class ValueComparator> implements Comparator> { boolean reverse; public ValueComparator(boolean reverse) { this.reverse = reverse; } public int compare(Map.Entry a, Map.Entry b) { return reverse ? b.getValue().compareTo(a.getValue()) : a.getValue().compareTo(b.getValue()); } } /** Returns the keys of map ordered by ascending value. */ public static > ArrayList orderKeysByValue( Map map) { return orderKeysByValue(map, false); } /** * Returns the keys of map ordered by value, ascending or (when reverse) descending. Keys with equal * values keep the map's iteration order because Collections.sort is stable. */ public static > ArrayList orderKeysByValue( Map map, boolean reverse) { ArrayList> keys = new ArrayList>(); keys.addAll(map.entrySet()); Collections.sort(keys, new ValueComparator(reverse)); ArrayList res = new ArrayList(); for (int i = 0; i < keys.size(); i++) { res.add(keys.get(i).getKey()); } return res; } /** * Returns at most n keys of map with the largest values, largest first. The result is presized * to min(n, map size), so a very large n (e.g. Integer.MAX_VALUE to mean "all") does not allocate * a huge backing array, and a negative n yields an empty list instead of throwing. */ public static > List topKeysByValue( Map map, int n) { ArrayList keys = orderKeysByValue(map, true); ArrayList res = new ArrayList(Math.max(0, Math.min(n, keys.size()))); for (int i = 0; i < n && i < keys.size(); i++) { res.add(keys.get(i)); } return res; } } ================================================ FILE: src/main/java/com/etsy/conjecture/data/AbstractInstance.java ================================================ package com.etsy.conjecture.data; import java.util.Collection; import java.util.List; import java.util.Map; /** * Base class for feature-vector instances: a StringKeyedVector of features plus an optional id, * free-form supporting data and an instance weight (1.0 unless given). Most mutators return this * cast to T so calls can be chained on the concrete subtype. Namespaced feature names are built * as namespace + SEP + name. */ public abstract class AbstractInstance> { /* separator placed between a namespace and a feature name */ protected static final String SEP = "___"; public String id; public String supporting_data; protected double weight; StringKeyedVector vector; public AbstractInstance() { this(new StringKeyedVector(), 1.0); } public AbstractInstance(double weight) { this(new StringKeyedVector(), weight); } public AbstractInstance(StringKeyedVector skv) { this(skv, 1.0); } /* the given vector is stored by reference, not copied */ public AbstractInstance(StringKeyedVector skv, double weight) { this.vector = skv; this.weight = weight; } public AbstractInstance(Map map) { this(map, 1.0); } public AbstractInstance(Map map, double weight) { this.vector = new StringKeyedVector(map); this.weight = weight; } 
@SuppressWarnings("unchecked") public T setWeight(double weight) { this.weight = weight; return (T)this; } public double getWeight() { return weight; } public String getId() { return id; } /* returns the internal vector itself, not a copy */ public StringKeyedVector getVector() { return vector; } public void setSupportingData(String s) { supporting_data = s; } public String getSupportingData() { return supporting_data; } @SuppressWarnings("unchecked") public T setCoordinate(String id, double value) { vector.setCoordinate(id, value); return (T)this; } @SuppressWarnings("unchecked") public T addToCoordinate(String id, double value) { vector.addToCoordinate(id, value); return (T)this; } @SuppressWarnings("unchecked") public T setId(String id) { this.id = id; return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#addTerm(java.lang.String) */ @SuppressWarnings("unchecked") public T addTerm(String term) { addTerm(term, 1.); return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#addTerm(java.lang.String, * double) */ @SuppressWarnings("unchecked") public T addTerm(String term, double featureWeight) { addToCoordinate(term, featureWeight); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addTermWithNamespace(java. * lang.String, java.lang.String) */ @SuppressWarnings("unchecked") public T addTermWithNamespace(String term, String namespace) { addTermWithNamespace(term, namespace, 1); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addTermWithNamespace(java.
* lang.String, java.lang.String, double) */ @SuppressWarnings("unchecked") public T addTermWithNamespace(String term, String namespace, double featureWeight) { addToCoordinate(namespace + SEP + term, featureWeight); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addTerms(java.util.Collection, * double) */ @SuppressWarnings("unchecked") public T addTerms(Collection terms, double featureWeight) { for (String term : terms) { addToCoordinate(term, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addTerms(java.util.Collection) */ @SuppressWarnings("unchecked") public T addTerms(Collection terms) { addTerms(terms, 1.); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addTermsWithNamespace(java * .util.Collection, java.lang.String, double) */ @SuppressWarnings("unchecked") public T addTermsWithNamespace(Collection terms, String namespace, double featureWeight) { for (String term : terms) { addTermWithNamespace(term, namespace, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addTermsWithNamespace(java * .util.Collection, java.lang.String) */ @SuppressWarnings("unchecked") public T addTermsWithNamespace(Collection terms, String namespace) { addTermsWithNamespace(terms, namespace, 1.); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addTerms(java.lang.String[], * double) */ @SuppressWarnings("unchecked") public T addTerms(String[] terms, double featureWeight) { for (String term : terms) { addToCoordinate(term, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addTerms(java.lang.String[]) */ @SuppressWarnings("unchecked") public T addTerms(String[] terms) { addTerms(terms, 1.); return (T)this; } /* * (non-Javadoc) * * @see *
com.etsy.conjecture.data.InstanceInterface#addTermsWithNamespace(java * .lang.String[], java.lang.String, double) */ @SuppressWarnings("unchecked") public T addTermsWithNamespace(String[] terms, String namespace, double featureWeight) { for (String term : terms) { addTermWithNamespace(term, namespace, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addTermsWithNamespace(java * .lang.String[], java.lang.String) */ @SuppressWarnings("unchecked") public T addTermsWithNamespace(String[] terms, String namespace) { addTermsWithNamespace(terms, namespace, 1.); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addTermsWithWeights(java.util * .Map) */ @SuppressWarnings("unchecked") public T addTermsWithWeights(Map termsWithWeights) { for (String term : termsWithWeights.keySet()) { addTerm(term, termsWithWeights.get(term)); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addTermsWithWeightsWithNamespace * (java.util.Map, java.lang.String) */ @SuppressWarnings("unchecked") public T addTermsWithWeightsWithNamespace( Map termsWithWeights, String namespace) { for (String term : termsWithWeights.keySet()) { addTermWithNamespace(term, namespace, termsWithWeights.get(term)); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addNumericArrayWithNamespace * (double[], java.lang.String) */ @SuppressWarnings("unchecked") public T addNumericArrayWithNamespace(double[] array, String namespace) { for (int i = 0; i < array.length; i++) { addToCoordinate(namespace + SEP + i, array[i]); } return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#addNumericArray(double[]) */ @SuppressWarnings("unchecked") public T addNumericArray(double[] array) { for (int i = 0; i < array.length; i++) { addToCoordinate("" + i, array[i]); } return (T)this; } /* * (non-Javadoc) * * @see *
com.etsy.conjecture.data.InstanceInterface#addNumericArrayWithNamespace * (java.lang.Double[], java.lang.String) */ @SuppressWarnings("unchecked") public T addNumericArrayWithNamespace(Double[] array, String namespace) { for (int i = 0; i < array.length; i++) { addToCoordinate(namespace + SEP + i, array[i]); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addNumericArray(java.lang. * Double[]) */ @SuppressWarnings("unchecked") public T addNumericArray(Double[] array) { for (int i = 0; i < array.length; i++) { addToCoordinate("" + i, array[i]); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addNumericArrayWithNamespace * (java.util.List, java.lang.String) */ @SuppressWarnings("unchecked") public T addNumericArrayWithNamespace(List values, String namespace) { for (int i = 0; i < values.size(); i++) { addToCoordinate(namespace + SEP + i, values.get(i)); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addNumericArray(java.util.
* List) */ @SuppressWarnings("unchecked") public T addNumericArray(List values) { for (int i = 0; i < values.size(); i++) { addToCoordinate("" + i, values.get(i)); } return (T)this; } /* NOTE(review): every set method from here to the end of the class (setNumericArray, setIdField, setIds and their namespace variants) delegates to addToCoordinate rather than setCoordinate, so each one behaves exactly like its add counterpart, presumably accumulating into an existing coordinate instead of overwriting it. Confirm against StringKeyedVector and the InstanceInterface contract before relying on overwrite semantics. */ /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setNumericArrayWithNamespace * (double[], java.lang.String) */ @SuppressWarnings("unchecked") public T setNumericArrayWithNamespace(double[] array, String namespace) { for (int i = 0; i < array.length; i++) { addToCoordinate(namespace + SEP + i, array[i]); } return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#setNumericArray(double[]) */ @SuppressWarnings("unchecked") public T setNumericArray(double[] array) { for (int i = 0; i < array.length; i++) { addToCoordinate("" + i, array[i]); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setNumericArrayWithNamespace * (java.lang.Double[], java.lang.String) */ @SuppressWarnings("unchecked") public T setNumericArrayWithNamespace(Double[] array, String namespace) { for (int i = 0; i < array.length; i++) { addToCoordinate(namespace + SEP + i, array[i]); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setNumericArray(java.lang. * Double[]) */ @SuppressWarnings("unchecked") public T setNumericArray(Double[] array) { for (int i = 0; i < array.length; i++) { addToCoordinate("" + i, array[i]); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setNumericArrayWithNamespace * (java.util.List, java.lang.String) */ @SuppressWarnings("unchecked") public T setNumericArrayWithNamespace(List values, String namespace) { for (int i = 0; i < values.size(); i++) { addToCoordinate(namespace + SEP + i, values.get(i)); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setNumericArray(java.util.
* List) */ @SuppressWarnings("unchecked") public T setNumericArray(List values) { for (int i = 0; i < values.size(); i++) { addToCoordinate("" + i, values.get(i)); } return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#addIdField(long, double) */ @SuppressWarnings("unchecked") public T addIdField(long id, double featureWeight) { addToCoordinate("" + id, featureWeight); return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#addIdField(long) */ @SuppressWarnings("unchecked") public T addIdField(long id) { addIdField(id, 1.); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addIdFieldWithNamespace(long, * double, java.lang.String) */ @SuppressWarnings("unchecked") public T addIdFieldWithNamespace(long id, double featureWeight, String namespace) { addToCoordinate(namespace + SEP + id, featureWeight); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addIdFieldWithNamespace(long, * java.lang.String) */ @SuppressWarnings("unchecked") public T addIdFieldWithNamespace(long id, String namespace) { addIdFieldWithNamespace(id, 1., namespace); return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#addIdField(int, double) */ @SuppressWarnings("unchecked") public T addIdField(int id, double featureWeight) { addToCoordinate("" + id, featureWeight); return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#addIdField(int) */ @SuppressWarnings("unchecked") public T addIdField(int id) { addIdField(id, 1.); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addIdFieldWithNamespace(int, * double, java.lang.String) */ @SuppressWarnings("unchecked") public T addIdFieldWithNamespace(int id, double featureWeight, String namespace) { addToCoordinate(namespace + SEP + id, featureWeight); return (T)this; } /* * (non-Javadoc) * * @see *
com.etsy.conjecture.data.InstanceInterface#addIdFieldWithNamespace(int, * java.lang.String) */ @SuppressWarnings("unchecked") public T addIdFieldWithNamespace(int id, String namespace) { addIdFieldWithNamespace(id, 1., namespace); return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#addIds(long[], double) */ @SuppressWarnings("unchecked") public T addIds(long[] ids, double featureWeight) { for (long id : ids) { addToCoordinate("" + id, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#addIds(long[]) */ @SuppressWarnings("unchecked") public T addIds(long[] ids) { addIds(ids, 1.); return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#addIds(int[], double) */ @SuppressWarnings("unchecked") public T addIds(int[] ids, double featureWeight) { for (long id : ids) { addToCoordinate("" + id, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#addIds(int[]) */ @SuppressWarnings("unchecked") public T addIds(int[] ids) { addIds(ids, 1.); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addIds(java.util.Collection, * double) */ @SuppressWarnings("unchecked") public T addIds(Collection ids, double featureWeight) { for (long id : ids) { addToCoordinate("" + id, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addIds(java.util.Collection) */ @SuppressWarnings("unchecked") public T addIds(Collection ids) { addIds(ids, 1.); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addIdsWithNamespace(long[], * double, java.lang.String) */ @SuppressWarnings("unchecked") public T addIdsWithNamespace(long[] ids, double featureWeight, String namespace) { for (long id : ids) { addToCoordinate(namespace + SEP + id, featureWeight); } return (T)this; } /* * (non-Javadoc) * *
@see * com.etsy.conjecture.data.InstanceInterface#addIdsWithNamespace(long[], * java.lang.String) */ @SuppressWarnings("unchecked") public T addIdsWithNamespace(long[] ids, String namespace) { addIdsWithNamespace(ids, 1., namespace); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addIdsWithNamespace(int[], * double, java.lang.String) */ @SuppressWarnings("unchecked") public T addIdsWithNamespace(int[] ids, double featureWeight, String namespace) { for (int id : ids) { addToCoordinate(namespace + SEP + id, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addIdsWithNamespace(int[], * java.lang.String) */ @SuppressWarnings("unchecked") public T addIdsWithNamespace(int[] ids, String namespace) { addIdsWithNamespace(ids, 1., namespace); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addIdsWithNamespace(java.util * .Collection, double, java.lang.String) */ @SuppressWarnings("unchecked") public T addIdsWithNamespace(Collection ids, double featureWeight, String namespace) { for (Long id : ids) { addToCoordinate(namespace + SEP + id, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#addIdsWithNamespace(java.util * .Collection, java.lang.String) */ @SuppressWarnings("unchecked") public T addIdsWithNamespace(Collection ids, String namespace) { addIdsWithNamespace(ids, 1., namespace); return (T)this; } /* NOTE(review): the setIdField, setIds and setIdsWithNamespace variants below also delegate to addToCoordinate, so they are identical to the corresponding add methods above. */ /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#setIdField(long, double) */ @SuppressWarnings("unchecked") public T setIdField(long id, double featureWeight) { addToCoordinate("" + id, featureWeight); return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#setIdField(long) */ @SuppressWarnings("unchecked") public T setIdField(long id) { setIdField(id, 1.); return (T)this; } /* * (non-Javadoc) * * @see *
com.etsy.conjecture.data.InstanceInterface#setIdFieldWithNamespace(long, * double, java.lang.String) */ @SuppressWarnings("unchecked") public T setIdFieldWithNamespace(long id, double featureWeight, String namespace) { addToCoordinate(namespace + SEP + id, featureWeight); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setIdFieldWithNamespace(long, * java.lang.String) */ @SuppressWarnings("unchecked") public T setIdFieldWithNamespace(long id, String namespace) { setIdFieldWithNamespace(id, 1., namespace); return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#setIdField(int, double) */ @SuppressWarnings("unchecked") public T setIdField(int id, double featureWeight) { addToCoordinate("" + id, featureWeight); return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#setIdField(int) */ @SuppressWarnings("unchecked") public T setIdField(int id) { setIdField(id, 1.); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setIdFieldWithNamespace(int, * double, java.lang.String) */ @SuppressWarnings("unchecked") public T setIdFieldWithNamespace(int id, double featureWeight, String namespace) { addToCoordinate(namespace + SEP + id, featureWeight); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setIdFieldWithNamespace(int, * java.lang.String) */ @SuppressWarnings("unchecked") public T setIdFieldWithNamespace(int id, String namespace) { setIdFieldWithNamespace(id, 1., namespace); return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#setIds(long[], double) */ @SuppressWarnings("unchecked") public T setIds(long[] ids, double featureWeight) { for (long id : ids) { addToCoordinate("" + id, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#setIds(long[]) */ @SuppressWarnings("unchecked") public T
setIds(long[] ids) { setIds(ids, 1.); return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#setIds(int[], double) */ @SuppressWarnings("unchecked") public T setIds(int[] ids, double featureWeight) { for (long id : ids) { addToCoordinate("" + id, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see com.etsy.conjecture.data.InstanceInterface#setIds(int[]) */ @SuppressWarnings("unchecked") public T setIds(int[] ids) { setIds(ids, 1.); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setIds(java.util.Collection, * double) */ @SuppressWarnings("unchecked") public T setIds(Collection ids, double featureWeight) { for (long id : ids) { addToCoordinate("" + id, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setIds(java.util.Collection) */ @SuppressWarnings("unchecked") public T setIds(Collection ids) { setIds(ids, 1.); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setIdsWithNamespace(long[], * double, java.lang.String) */ @SuppressWarnings("unchecked") public T setIdsWithNamespace(long[] ids, double featureWeight, String namespace) { for (long id : ids) { addToCoordinate(namespace + SEP + id, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setIdsWithNamespace(long[], * java.lang.String) */ @SuppressWarnings("unchecked") public T setIdsWithNamespace(long[] ids, String namespace) { setIdsWithNamespace(ids, 1., namespace); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setIdsWithNamespace(int[], * double, java.lang.String) */ @SuppressWarnings("unchecked") public T setIdsWithNamespace(int[] ids, double featureWeight, String namespace) { for (int id : ids) { addToCoordinate(namespace + SEP + id, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see *
com.etsy.conjecture.data.InstanceInterface#setIdsWithNamespace(int[], * java.lang.String) */ @SuppressWarnings("unchecked") public T setIdsWithNamespace(int[] ids, String namespace) { setIdsWithNamespace(ids, 1., namespace); return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setIdsWithNamespace(java.util * .Collection, double, java.lang.String) */ @SuppressWarnings("unchecked") public T setIdsWithNamespace(Collection ids, double featureWeight, String namespace) { for (Long id : ids) { addToCoordinate(namespace + SEP + id, featureWeight); } return (T)this; } /* * (non-Javadoc) * * @see * com.etsy.conjecture.data.InstanceInterface#setIdsWithNamespace(java.util * .Collection, java.lang.String) */ @SuppressWarnings("unchecked") public T setIdsWithNamespace(Collection ids, String namespace) { setIdsWithNamespace(ids, 1., namespace); return (T)this; } } ================================================ FILE: src/main/java/com/etsy/conjecture/data/BinaryLabel.java ================================================ package com.etsy.conjecture.data; import static com.google.common.base.Preconditions.checkArgument; /** * Label for binary problems: a RealValuedLabel whose value must lie in [0, 1]. The no-arg * constructor uses 0.0. */ public class BinaryLabel extends RealValuedLabel { private static final long serialVersionUID = 1L; public BinaryLabel() { super(0.0); } public BinaryLabel(double value) { super(checkBinaryValue(value)); } /* Throws IllegalArgumentException (via checkArgument) for values outside [0, 1]; NaN fails both comparisons and is rejected too. */ private static double checkBinaryValue(double value) { checkArgument(value >= 0 && value <= 1, "value must be in [0, 1], given: %s", value); return value; } /** Maps the label from {0, +1} onto {-1, +1} via 2 * (value - 0.5); intermediate values map linearly, so 0.5 becomes 0. */ public double getAsPlusMinus() { return 2.0 * (getValue() - 0.5); } } ================================================ FILE: src/main/java/com/etsy/conjecture/data/BinaryLabeledInstance.java ================================================ package com.etsy.conjecture.data; import java.util.Map; /** * TODO: when using method string all methods return a RealValueLabeledInstance * think about how to avoid this while not using generic types */ public
class BinaryLabeledInstance extends AbstractInstance implements LabeledInstance { protected BinaryLabel label; public BinaryLabel getLabel() { return label; } /* Constructors accept the label as a double (wrapped in a BinaryLabel, which rejects values outside [0, 1]) or as a BinaryLabel; features as a Map, as a StringKeyedVector (only vec.getMap() is passed on, and the Map-based AbstractInstance constructor wraps it in a new StringKeyedVector), or not at all. The weight defaults to 1.0 and the no-arg form uses label 0.0. */ public BinaryLabeledInstance() { this(new BinaryLabel(0.0), 1.0); } public BinaryLabeledInstance(double label, Map instance) { this(new BinaryLabel(label), instance, 1.0); } public BinaryLabeledInstance(double label, Map instance, double weight) { this(new BinaryLabel(label), instance, weight); } public BinaryLabeledInstance(double label, StringKeyedVector vec) { this(new BinaryLabel(label), vec.getMap(), 1.0); } public BinaryLabeledInstance(double label, StringKeyedVector vec, double weight) { this(new BinaryLabel(label), vec.getMap(), weight); } public BinaryLabeledInstance(BinaryLabel label, Map instance) { this(label, instance, 1.0); } public BinaryLabeledInstance(BinaryLabel label, Map instance, double weight) { super(instance, weight); this.label = label; } public BinaryLabeledInstance(BinaryLabel label, StringKeyedVector vec) { this(label, vec.getMap(), 1.0); } public BinaryLabeledInstance(BinaryLabel label, StringKeyedVector vec, double weight) { this(label, vec.getMap(), weight); } public BinaryLabeledInstance(double label) { this(new BinaryLabel(label), 1.0); } public BinaryLabeledInstance(double label, double weight) { this(new BinaryLabel(label), weight); } public BinaryLabeledInstance(BinaryLabel label) { this(label, 1.0); } public BinaryLabeledInstance(BinaryLabel label, double weight) { super(weight); this.label = label; } } ================================================ FILE: src/main/java/com/etsy/conjecture/data/ByteArrayDoubleHashMap.java ================================================ package com.etsy.conjecture.data; import gnu.trove.function.TDoubleFunction; import gnu.trove.iterator.TObjectDoubleIterator; import gnu.trove.map.hash.TObjectDoubleHashMap; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; import 
java.io.UnsupportedEncodingException; import java.util.AbstractMap; import java.util.Arrays; import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Set; import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.KryoSerializable; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; public class ByteArrayDoubleHashMap implements Serializable, KryoSerializable, Iterable>, Map { private static final long serialVersionUID = -7070522686694887436L; // - represent the sparse map by a mapping of coordinate name strings // (feature names) // to doubles. protected TObjectDoubleHashMap map; protected String keyEncoding; protected float loadFactor; protected double defaultValue; public ByteArrayDoubleHashMap() { this(10, 0.8f, 0.0); } public ByteArrayDoubleHashMap(int initialCapacity, float loadFactor, double defaultValue) { this(initialCapacity, loadFactor, "ASCII", defaultValue); } public ByteArrayDoubleHashMap(int initialCapacity, float loadFactor, String keyEncoding, double defaultValue) { this.map = new TByteArrayDoubleHashMap(initialCapacity, loadFactor, defaultValue); this.keyEncoding = keyEncoding; this.loadFactor = loadFactor; this.defaultValue = defaultValue; } public String byteArrayToString(byte[] b) { try { return new String(b, keyEncoding); } catch (UnsupportedEncodingException e) { e.printStackTrace(); return null; } } public byte[] stringToByteArray(String s) { try { return s.getBytes(keyEncoding); } catch (UnsupportedEncodingException e) { e.printStackTrace(); return null; } } /** * Customized trove hashmap which does both: customized hash/equality * functions, and also storing the values as a primitive array. 
*/ static class TByteArrayDoubleHashMap extends TObjectDoubleHashMap { public TByteArrayDoubleHashMap(int initialSize, float loadFactor, double defaultValue) { super(initialSize, loadFactor, defaultValue); } protected int hash(Object obj) { return Arrays.hashCode((byte[])obj); } protected boolean equals(Object a, Object b) { return b != null && b != REMOVED && Arrays.equals((byte[])a, (byte[])b); } // - ovrride this to prevent doubling on resize. public double put(byte[] key, double value) { int index = insertKey(key); double previous = 0.0; boolean isNewMapping = true; if (index < 0) { index = -index - 1; previous = _values[index]; isNewMapping = false; } _values[index] = value; if (isNewMapping) { postInsertHook2(consumeFreeSlot); } return previous; } protected final void postInsertHook2(boolean usedFreeSlot) { if (usedFreeSlot) { _free--; } if (++_size > _maxSize || _free == 0) { int newCapacity = _size > _maxSize ? gnu.trove.impl.PrimeFinder .nextPrime((int)(capacity() * 1.2) + 10) : capacity(); if (newCapacity > 1000000) { System.out.println("rehashing to size: " + newCapacity + " from " + capacity()); } rehash(newCapacity); computeMaxSize(capacity()); } } } public int size() { return map.size(); } public boolean containsKey(Object key) { if (key instanceof byte[]) { return map.containsKey(key); } else if (key instanceof String) { return map.containsKey(stringToByteArray((String)key)); } else { throw new IllegalArgumentException("class " + key.getClass().toString() + " is not valid for ByteArrayDoubleHashMap.containsKey"); } } public Set keySet() { Set res = new HashSet(); for (byte[] b : map.keySet()) { res.add(byteArrayToString(b)); } return res; } public Set values() { Set values = new HashSet(); for (Map.Entry e : this) { values.add(e.getValue()); } return values; } public boolean containsValue(Object d) { return values().contains((Double)d); } public Set> entrySet() { Set> entries = new HashSet>(); for (Map.Entry e : this) { entries.add(e); } return 
entries; } public boolean isEmpty() { return size() > 0; } public void clear() { map.clear(); } public Double remove(Object k) { return removePrimitive((String)k); } public Double get(Object k) { return getPrimitive((String)k); } public Double put(String key, Double value) { return putPrimitive(key, value); } public void putAll(Map m) { for (Map.Entry e : m.entrySet()) { put((String)e.getKey(), (Double)e.getValue()); } } public double getPrimitive(byte[] key) { return map.get(key); } public double getPrimitive(String key) { return map.get(stringToByteArray(key)); } public double putPrimitive(byte[] key, double value) { return map.put(key, value); } public double putPrimitive(String key, double value) { return map.put(stringToByteArray(key), value); } public double removePrimitive(byte[] key) { return map.remove(key); } public double removePrimitive(String key) { return map.remove(stringToByteArray(key)); } public void transformValues(TDoubleFunction func) { map.transformValues(func); } public TObjectDoubleIterator troveIterator() { return map.iterator(); } public Iterator> iterator() { return new Iterator>() { private TObjectDoubleIterator iter = troveIterator(); public boolean hasNext() { return iter.hasNext(); } public void remove() { iter.remove(); } public Map.Entry next() { iter.advance(); return new AbstractMap.SimpleImmutableEntry( byteArrayToString(iter.key()), iter.value()); } }; } // - java serialization private void writeObject(ObjectOutputStream output) throws IOException { output.writeObject(keyEncoding); output.writeFloat(loadFactor); output.writeDouble(defaultValue); output.writeInt(map.size()); for (TObjectDoubleIterator it = map.iterator(); it.hasNext();) { it.advance(); byte[] key = it.key(); output.writeInt(key.length); for (int i = 0; i < key.length; i++) { output.writeByte(key[i]); } output.writeDouble(it.value()); } } private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException { keyEncoding = 
(String)input.readObject(); loadFactor = input.readFloat(); defaultValue = input.readDouble(); int size = input.readInt(); map = new TByteArrayDoubleHashMap(size, loadFactor, defaultValue); for (int i = 0; i < size; i++) { int length = input.readInt(); byte[] key = new byte[length]; for (int j = 0; j < length; j++) { key[j] = input.readByte(); } double value = input.readDouble(); map.put(key, value); } } // - kryo serialization for use in scalding. public void write(Kryo kryo, Output output) { output.writeString(keyEncoding); output.writeFloat(loadFactor); output.writeDouble(defaultValue); output.writeInt(map.size()); for (TObjectDoubleIterator it = map.iterator(); it.hasNext();) { it.advance(); byte[] key = it.key(); output.writeInt(key.length); for (int i = 0; i < key.length; i++) { output.writeByte(key[i]); } output.writeDouble(it.value()); } } public void read(Kryo kryo, Input input) { keyEncoding = input.readString(); loadFactor = input.readFloat(); defaultValue = input.readDouble(); int size = input.readInt(); map = new TByteArrayDoubleHashMap(size, loadFactor, defaultValue); for (int i = 0; i < size; i++) { int length = input.readInt(); byte[] key = new byte[length]; for (int j = 0; j < length; j++) { key[j] = input.readByte(); } double value = input.readDouble(); map.put(key, value); } } } ================================================ FILE: src/main/java/com/etsy/conjecture/data/ClusterLabel.java ================================================ package com.etsy.conjecture.data; public class ClusterLabel extends Label{ private static final long serialVersionUID = 1L; protected String label; public ClusterLabel() { this(null); } public ClusterLabel(String label) { this.label = label; } public String getLabel() { return this.label; } public void setLabel(String label) { this.label = label; } public String toString() { return label; } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + ((label == null) ? 
0 : label.hashCode()); return result; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; ClusterLabel other = (ClusterLabel) obj; if (label == null) { if (other.label != null) return false; } else if (!label.equals(other.label)) return false; return true; } } ================================================ FILE: src/main/java/com/etsy/conjecture/data/ClusterPrediction.java ================================================ package com.etsy.conjecture.data; import java.util.Map; import com.google.common.collect.Maps; /** * Representing a probability of membership in each cluster */ public class ClusterPrediction extends ClusterLabel{ private static final long serialVersionUID = -1L; /** * Cluster membership probabilities */ private Map clusterProbs; public ClusterPrediction(Map clusterProbs) { this.clusterProbs = Maps.newHashMap(clusterProbs); boolean first = true; double maxProb = 0; String maxCategory = null; for (String key : clusterProbs.keySet()) { if(first || clusterProbs.get(key) > maxProb) { maxProb = clusterProbs.get(key); maxCategory = key; first = false; } } setLabel(maxCategory); } public Map getMap() { return clusterProbs; } } ================================================ FILE: src/main/java/com/etsy/conjecture/data/Instance.java ================================================ package com.etsy.conjecture.data; //TODO: reset methods for string adders //TODO: for instance, vector subtraction? 
/**
 * An unlabeled instance.
 */
public class Instance extends AbstractInstance<Instance> {

    public Instance() {
        super();
    }

    public Instance(StringKeyedVector vec) {
        super(vec);
    }
}

================================================
FILE: src/main/java/com/etsy/conjecture/data/InstanceFactory.java
================================================
package com.etsy.conjecture.data;

/**
 * Static helpers for building instances and attaching labels to them.
 */
public class InstanceFactory {

    private InstanceFactory() {
    }

    public static Instance buildInstance() {
        return new Instance();
    }

    // NOTE(review): whether this copies or shares the feature vector depends on
    // the AbstractInstance(StringKeyedVector) constructor -- confirm.
    public static Instance copyInstance(Instance inst) {
        return new Instance(inst.getVector());
    }

    public static BinaryLabeledInstance toBinaryLabeledInstance(double label, Instance instance) {
        return new BinaryLabeledInstance(label, instance.getVector());
    }

    public static BinaryLabeledInstance toBinaryLabeledInstance(BinaryLabel label, Instance instance) {
        return new BinaryLabeledInstance(label, instance.getVector());
    }

    public static RealValueLabeledInstance toRealValueLabeledInstance(double label, Instance instance) {
        return new RealValueLabeledInstance(label, instance.getVector());
    }

    public static RealValueLabeledInstance toRealValueLabeledInstance(RealValuedLabel label,
            Instance instance) {
        return new RealValueLabeledInstance(label, instance.getVector());
    }
}

================================================
FILE: src/main/java/com/etsy/conjecture/data/InstanceInterface.java
================================================
package com.etsy.conjecture.data;

import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
 * Fluent feature-building API shared by all instance types; every mutator
 * returns the instance itself (self type T) so calls can be chained.
 */
public interface InstanceInterface<T extends InstanceInterface<T>> {

    public abstract String getId();

    public abstract T setId(String id);

    // - terms

    public abstract T addTerm(String term);

    public abstract T addTerm(String term, double featureWeight);

    public abstract T addTermWithNamespace(String term, String namespace);

    public abstract T addTermWithNamespace(String term, String namespace, double featureWeight);

    public abstract T addTerms(Collection<String> terms, double featureWeight);

    public abstract T addTerms(Collection<String> terms);

    public abstract T addTermsWithNamespace(Collection<String> terms, String namespace,
            double featureWeight);

    public abstract T addTermsWithNamespace(Collection<String> terms, String namespace);

    public abstract T addTerms(String[] terms, double featureWeight);

    public abstract T addTerms(String[] terms);

    public abstract T addTermsWithNamespace(String[] terms, String namespace, double featureWeight);

    public abstract T addTermsWithNamespace(String[] terms, String namespace);

    public abstract T addTermsWithWeights(Map<String, Double> termsWithWeights);

    public abstract T addTermsWithWeightsWithNamespace(Map<String, Double> termsWithWeights,
            String namespace);

    // - numeric arrays

    public abstract T addNumericArrayWithNamespace(double[] array, String namespace);

    public abstract T addNumericArray(double[] array);

    public abstract T addNumericArrayWithNamespace(Double[] array, String namespace);

    public abstract T addNumericArray(Double[] array);

    public abstract T addNumericArrayWithNamespace(List<Double> values, String namespace);

    public abstract T addNumericArray(List<Double> values);

    public abstract T setNumericArrayWithNamespace(double[] array, String namespace);

    public abstract T setNumericArray(double[] array);

    public abstract T setNumericArrayWithNamespace(Double[] array, String namespace);

    public abstract T setNumericArray(Double[] array);

    public abstract T setNumericArrayWithNamespace(List<Double> values, String namespace);

    public abstract T setNumericArray(List<Double> values);

    // - id fields

    public abstract T addIdField(long id, double featureWeight);

    public abstract T addIdField(long id);

    public abstract T addIdFieldWithNamespace(long id, double featureWeight, String namespace);

    public abstract T addIdFieldWithNamespace(long id, String namespace);

    public abstract T addIdField(int id, double featureWeight);

    public abstract T addIdField(int id);

    public abstract T addIdFieldWithNamespace(int id, double featureWeight, String namespace);

    public abstract T addIdFieldWithNamespace(int id, String namespace);

    public abstract T addIds(long[] ids, double featureWeight);

    public abstract T addIds(long[] ids);

    public abstract T addIds(int[] ids, double featureWeight);

    public abstract T addIds(int[] ids);

    public abstract T addIds(Collection<Long> ids, double featureWeight);

    public abstract T addIds(Collection<Long> ids);

    public abstract T addIdsWithNamespace(long[] ids, double featureWeight, String namespace);

    public abstract T addIdsWithNamespace(long[] ids, String namespace);

    public abstract T addIdsWithNamespace(int[] ids, double featureWeight, String namespace);

    public abstract T addIdsWithNamespace(int[] ids, String namespace);

    public abstract T addIdsWithNamespace(Collection<Long> ids, double featureWeight,
            String namespace);

    public abstract T addIdsWithNamespace(Collection<Long> ids, String namespace);

    public abstract T setIdField(long id, double featureWeight);

    public abstract T setIdField(long id);

    public abstract T setIdFieldWithNamespace(long id, double featureWeight, String namespace);

    public abstract T setIdFieldWithNamespace(long id, String namespace);

    public abstract T setIdField(int id, double featureWeight);

    public abstract T setIdField(int id);

    public abstract T setIdFieldWithNamespace(int id, double featureWeight, String namespace);

    public abstract T setIdFieldWithNamespace(int id, String namespace);

    public abstract T setIds(long[] ids, double featureWeight);

    public abstract T setIds(long[] ids);

    public abstract T setIds(int[] ids, double featureWeight);

    public abstract T setIds(int[] ids);

    public abstract T setIds(Collection<Long> ids, double featureWeight);

    public abstract T setIds(Collection<Long> ids);

    public abstract T setIdsWithNamespace(long[] ids, double featureWeight, String namespace);

    public abstract T setIdsWithNamespace(long[] ids, String namespace);

    public abstract T setIdsWithNamespace(int[] ids, double featureWeight, String namespace);

    public abstract T setIdsWithNamespace(int[] ids, String namespace);

    public abstract T setIdsWithNamespace(Collection<Long> ids, double featureWeight,
            String namespace);

    public abstract T setIdsWithNamespace(Collection<Long> ids, String namespace);
}

================================================
FILE: src/main/java/com/etsy/conjecture/data/Label.java
================================================
package com.etsy.conjecture.data;

/** Serializable base type for all labels. */
public class Label implements java.io.Serializable {
    private static final long serialVersionUID = 1L;

    public Label() {
    }
}

================================================
FILE: src/main/java/com/etsy/conjecture/data/LabeledInstance.java
================================================
package com.etsy.conjecture.data;

/** A weighted feature vector paired with a label of type L. */
public interface LabeledInstance<L extends Label> {
    public L getLabel();

    public StringKeyedVector getVector();

    public double getWeight();
}

================================================
FILE: src/main/java/com/etsy/conjecture/data/LazyVector.java
================================================
package com.etsy.conjecture.data;

import gnu.trove.function.TDoubleFunction;
import gnu.trove.iterator.TObjectDoubleIterator;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import com.etsy.conjecture.Utilities;

/**
 * A StringKeyedVector whose coordinates are updated lazily.
 *
 * The vector keeps a global {@link #iteration} counter and, per coordinate,
 * the iteration at which that coordinate was last brought up to date (in
 * {@link #iterations}). When a coordinate is read it is "delazified": the
 * {@link UpdateFunction} is applied for the skipped iterations, and
 * coordinates that become zero are removed. Whole-vector operations (size,
 * keySet, iterator, ...) delazify every coordinate first.
 */
public class LazyVector extends StringKeyedVector implements Serializable, KryoSerializable {

    private static final long serialVersionUID = -7070522686694887436L;

    // per-coordinate "last updated at iteration" values, stored as doubles.
    protected transient ByteArrayDoubleHashMap iterations;
    protected long iteration = 0;
    protected UpdateFunction updater;

    /**
     * The function used to update the parameters during the lazy update
     */
    public static interface UpdateFunction extends Serializable {
        /**
         * @param key decoded feature name of the coordinate
         * @param param current (stale) value of the coordinate
         * @param startIteration iteration the value was last updated at
         * @param endIteration current iteration
         * @return the up-to-date value
         */
        public double lazyUpdate(String key, double param, long startIteration, long endIteration);
    }

    /** Creates a vector whose update function is the identity. */
    public LazyVector() {
        this(new UpdateFunction() {
            private static final long serialVersionUID = 1740773207106961880L;

            public double lazyUpdate(String key, double p, long a, long b) {
                return p;
            }
        });
    }

    public LazyVector(UpdateFunction uf) {
        this(10, uf);
    }

    public LazyVector(int initialCapacity, UpdateFunction uf) {
        super(initialCapacity);
        iterations = new ByteArrayDoubleHashMap(initialCapacity, LOAD_FACTOR, FEATURE_ENCODING, 0.0);
        updater = uf;
    }

    /**
     * Wraps skv's underlying map (shared, not copied); a LazyVector argument is
     * delazified first.
     */
    public LazyVector(StringKeyedVector skv, UpdateFunction uf) {
        if (skv instanceof LazyVector) {
            ((LazyVector)skv).delazify();
        }
        this.vector = skv.vector;
        iterations = new ByteArrayDoubleHashMap(skv.size(), LOAD_FACTOR, FEATURE_ENCODING, 0.0);
        updater = uf;
    }

    public LazyVector(ByteArrayDoubleHashMap map, UpdateFunction uf) {
        super(map);
        iterations = new ByteArrayDoubleHashMap(10, LOAD_FACTOR, FEATURE_ENCODING, 0.0);
        updater = uf;
    }

    public LazyVector(Map<String, Double> jmap, UpdateFunction uf) {
        super(jmap);
        iterations = new ByteArrayDoubleHashMap(10, LOAD_FACTOR, FEATURE_ENCODING, 0.0);
        updater = uf;
    }

    public void incrementIteration() {
        iteration++;
    }

    /** Brings every coordinate up to the current iteration, then drops zeros. */
    public void delazify() {
        for (TObjectDoubleIterator<byte[]> it = vector.troveIterator(); it.hasNext();) {
            it.advance();
            // coordinates with no record read the iterations map's default value.
            long startIter = (long)iterations.getPrimitive(it.key());
            if (startIter < iteration) {
                // decode the key: byte[].toString() would only give an identity string.
                String key = vector.byteArrayToString(it.key());
                it.setValue(updater.lazyUpdate(key, it.value(), startIter, iteration));
                iterations.putPrimitive(it.key(), (double)iteration);
            }
        }
        removeZeroCoordinates();
    }

    public double delazifyCoordinate(String key) {
        return delazifyCoordinate(vector.stringToByteArray(key));
    }

    /**
     * Brings one coordinate up to date and returns its value (0.0 if absent).
     * A coordinate whose updated value is zero is removed.
     */
    public double delazifyCoordinate(byte[] key) {
        if (vector.containsKey(key)) {
            long oldIteration = (long)iterations.getPrimitive(key);
            double initial = vector.getPrimitive(key);
            if (oldIteration < iteration) {
                // decode the key: byte[].toString() would only give an identity string.
                double updated = updater.lazyUpdate(vector.byteArrayToString(key), initial,
                        oldIteration, iteration);
                if (Utilities.floatingPointEquals(updated, 0.0d)) {
                    vector.removePrimitive(key);
                    iterations.removePrimitive(key);
                } else {
                    iterations.putPrimitive(key, (double)iteration);
                    vector.putPrimitive(key, updated);
                }
                return updated;
            } else {
                return initial;
            }
        }
        return 0.0;
    }

    /** Delazifies everything, then marks all coordinates as current at iter. */
    public void skipToIteration(long iter) {
        delazify();
        iteration = iter;
        for (TObjectDoubleIterator<byte[]> it = iterations.troveIterator(); it.hasNext();) {
            it.advance();
            it.setValue((double)iter);
        }
    }

    /**
     * disregards prior value at a particular key, replacing with the specified
     * value.
     */
    public double setCoordinate(String key, double value) {
        if (Utilities.floatingPointEquals(value, 0d)) {
            return deleteCoordinate(key);
        } else if (!freezeKeySet) {
            vector.putPrimitive(key, value);
            iterations.putPrimitive(key, (double)iteration);
        }
        return 0d;
    }

    /**
     * remove a coordinate from the vector (same as setting it to 0).
     */
    public double deleteCoordinate(String key) {
        if (vector.containsKey(key) && !freezeKeySet) {
            iterations.removePrimitive(key);
            return vector.removePrimitive(key);
        } else {
            return 0d;
        }
    }

    /** Returns the underlying map as-is, without delazifying it. */
    public Map<String, Double> getMap() {
        return vector;
    }

    protected double addToCoordinateInternal(byte[] bkey, double value) {
        delazifyCoordinate(bkey);
        if (vector.containsKey(bkey)) {
            double updated = vector.getPrimitive(bkey) + value;
            if (Utilities.floatingPointEquals(updated, 0.0d)) {
                iterations.removePrimitive(bkey);
                return vector.removePrimitive(bkey);
            } else {
                iterations.putPrimitive(bkey, (double)iteration);
                return vector.putPrimitive(bkey, updated);
            }
        } else if (!freezeKeySet && !Utilities.floatingPointEquals(value, 0.0d)) {
            vector.putPrimitive(bkey, value);
            iterations.putPrimitive(bkey, (double)iteration);
        }
        return 0d;
    }

    /**
     * return the value of a coordinate.
     */
    public double getCoordinate(String key) {
        delazifyCoordinate(key);
        return vector.getPrimitive(key);
    }

    /**
     * the dimension of the vector.
     */
    public int size() {
        delazify();
        return vector.size();
    }

    /**
     * whether this vector has a non-zero value for a coordinate.
     */
    public boolean containsKey(String key) {
        delazify();
        return vector.containsKey(key);
    }

    /**
     * whether this vector has a non-zero value for a coordinate.
     */
    public boolean contains(String key) {
        return containsKey(key);
    }

    /**
     * the set of non-zero coordinate names.
     */
    public Set<String> keySet() {
        delazify();
        return vector.keySet();
    }

    /**
     * the set of values in the map.
     */
    public Set<Double> values() {
        delazify();
        return vector.values();
    }

    /**
     * Apply an arbitrary scalar function to the values.
     */
    public void transformValues(TDoubleFunction func) {
        delazify();
        vector.transformValues(func);
    }

    /**
     * Remove zeros that may have appeared as a result of a transform
     */
    public void removeZeroCoordinates() {
        for (TObjectDoubleIterator<byte[]> it = vector.troveIterator(); it.hasNext();) {
            it.advance();
            if (Utilities.floatingPointEquals(it.value(), 0d)) {
                iterations.removePrimitive(it.key());
                it.remove();
            }
        }
    }

    /**
     * compute the inner product between this and vec.
     */
    public double dot(StringKeyedVector skv) {
        if (skv instanceof LazyVector) {
            return dotWithLazy((LazyVector)skv);
        } else {
            return dotWithSKV(skv);
        }
    }

    protected double dotWithSKV(StringKeyedVector vec) {
        // dont figure out which ones bigger etc, since delazifying this to get
        // the size is too slow.
        double res = 0.0;
        for (TObjectDoubleIterator<byte[]> it = vec.vector.troveIterator(); it.hasNext();) {
            it.advance();
            res += it.value() * delazifyCoordinate(it.key());
        }
        return res;
    }

    protected double dotWithLazy(LazyVector vec) {
        // size() fully delazifies each vector, so evaluate it once per vector.
        boolean thisIsBigger = this.size() > vec.size();
        ByteArrayDoubleHashMap vec_small = thisIsBigger ? vec.vector : this.vector;
        ByteArrayDoubleHashMap vec_big = thisIsBigger ? this.vector : vec.vector;
        // collect keys first to prevent modification during iteration.
        ArrayList<byte[]> commonCoordinates = new ArrayList<byte[]>();
        double res = 0.0;
        for (TObjectDoubleIterator<byte[]> it = vec_small.troveIterator(); it.hasNext();) {
            it.advance();
            if (vec_big.containsKey(it.key())) {
                commonCoordinates.add(it.key());
            }
        }
        for (byte[] key : commonCoordinates) {
            delazifyCoordinate(key);
            vec.delazifyCoordinate(key);
            res += vec_small.getPrimitive(key) * vec_big.getPrimitive(key);
        }
        return res;
    }

    /**
     * compute the LP norm for given p < infinity.
     */
    public double LPNorm(double p) {
        delazify();
        return super.LPNorm(p);
    }

    /**
     * immutable access the underlying hash map.
     */
    public Iterator<Map.Entry<String, Double>> iterator() {
        delazify();
        return vector.iterator();
    }

    public String toString() {
        delazify();
        return super.toString();
    }

    // Delazify before Java serialization so the written values are current.
    private Object writeReplace() throws java.io.ObjectStreamException {
        delazify();
        return this;
    }

    // - java serialization
    private void writeObject(ObjectOutputStream output) throws IOException {
        output.writeLong(iteration);
        output.writeObject(vector);
        output.writeObject(updater);
        output.writeBoolean(freezeKeySet);
    }

    private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException {
        iteration = input.readLong();
        vector = (ByteArrayDoubleHashMap)input.readObject();
        updater = (UpdateFunction)input.readObject();
        freezeKeySet = input.readBoolean();
        // set up iteration info: the default value is the current iteration, so
        // every restored coordinate is treated as already up to date.
        iterations = new ByteArrayDoubleHashMap(10, LOAD_FACTOR, FEATURE_ENCODING, (double)iteration);
    }

    // - kryo serialization for use in scalding.
    public void write(Kryo kryo, Output output) {
        delazify();
        output.writeLong(iteration);
        kryo.writeObject(output, vector);
        kryo.writeClassAndObject(output, updater);
        output.writeBoolean(freezeKeySet);
    }

    public void read(Kryo kryo, Input input) {
        iteration = input.readLong();
        vector = kryo.readObject(input, ByteArrayDoubleHashMap.class);
        updater = (UpdateFunction)kryo.readClassAndObject(input);
        freezeKeySet = input.readBoolean();
        // set up iteration info (see readObject).
        iterations = new ByteArrayDoubleHashMap(10, LOAD_FACTOR, FEATURE_ENCODING, (double)iteration);
    }
}

================================================
FILE: src/main/java/com/etsy/conjecture/data/MulticlassLabel.java
================================================
package com.etsy.conjecture.data;

/**
 * representing a 100% probability of membership in a particular class
 */
public class MulticlassLabel extends Label {
    private static final long serialVersionUID = 1L;

    protected String label;

    public MulticlassLabel() {
        this(null);
    }

    public MulticlassLabel(String label) {
        this.label = label;
    }

    public String getLabel() {
        return this.label;
    }

    public void setLabel(String label) {
        this.label = label;
    }

    public String toString() {
        return label;
    }

    /** One-vs-rest view: 1.0 if this label equals className, else 0.0. */
    public BinaryLabel toBinaryLabel(String className) {
        return new BinaryLabel(className.equals(label) ? 1.0 : 0.0);
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + ((label == null) ? 0 : label.hashCode());
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        MulticlassLabel other = (MulticlassLabel)obj;
        if (label == null) {
            if (other.label != null)
                return false;
        } else if (!label.equals(other.label))
            return false;
        return true;
    }
}

================================================
FILE: src/main/java/com/etsy/conjecture/data/MulticlassLabeledInstance.java
================================================
package com.etsy.conjecture.data;

import java.util.Map;

/** An instance (sparse feature vector plus weight) carrying a {@link MulticlassLabel}. */
public class MulticlassLabeledInstance extends AbstractInstance<MulticlassLabeledInstance>
        implements LabeledInstance<MulticlassLabel> {

    protected MulticlassLabel label;

    public MulticlassLabel getLabel() {
        return label;
    }

    public MulticlassLabeledInstance(String label) {
        this(new MulticlassLabel(label), 1.0);
    }

    public MulticlassLabeledInstance(String label, double weight) {
        this(new MulticlassLabel(label), weight);
    }

    public MulticlassLabeledInstance(String label, Map<String, Double> instance) {
        this(new MulticlassLabel(label), instance, 1.0);
    }

    public MulticlassLabeledInstance(String label, Map<String, Double> instance, double weight) {
        this(new MulticlassLabel(label), instance, weight);
    }

    public MulticlassLabeledInstance(String label, StringKeyedVector vec) {
        this(new MulticlassLabel(label), vec.getMap(), 1.0);
    }

    public MulticlassLabeledInstance(String label, StringKeyedVector vec, double weight) {
        this(new MulticlassLabel(label), vec.getMap(), weight);
    }

    public MulticlassLabeledInstance(MulticlassLabel label) {
        this(label, 1.0);
    }

    public MulticlassLabeledInstance(MulticlassLabel label, double weight) {
        super(weight);
        this.label = label;
    }

    public MulticlassLabeledInstance(MulticlassLabel label, Map<String, Double> instance) {
        this(label, instance, 1.0);
    }

    public MulticlassLabeledInstance(MulticlassLabel label, Map<String, Double> instance,
            double weight) {
        super(instance, weight);
        this.label = label;
    }

    public MulticlassLabeledInstance(MulticlassLabel label, StringKeyedVector vec) {
        this(label, vec.getMap(), 1.0);
    }

    public MulticlassLabeledInstance(MulticlassLabel label, StringKeyedVector vec, double weight) {
        this(label, vec.getMap(), weight);
    }

    /**
     * One-vs-rest conversion: label 1.0 if this instance's class equals
     * category, else 0.0. The new instance uses getVector() and the
     * constructor's default weight of 1.0.
     */
    public BinaryLabeledInstance toBinaryInstance(String category) {
        double binaryLabel = category.equals(this.label.getLabel()) ? 1d : 0d;
        return new BinaryLabeledInstance(binaryLabel, getVector());
    }
}

================================================
FILE: src/main/java/com/etsy/conjecture/data/MulticlassPrediction.java
================================================
package com.etsy.conjecture.data;

import java.util.Map;

import com.google.common.collect.Maps;

/**
 * representing a probability of membership in each class; the label is the
 * class with the highest probability (null if the map is empty).
 */
public class MulticlassPrediction extends MulticlassLabel {
    private static final long serialVersionUID = -1L;

    /**
     * class membership probabilities (a defensive copy of the constructor argument)
     */
    private Map<String, Double> classProbs;

    public MulticlassPrediction(Map<String, Double> classProbs) {
        this.classProbs = Maps.newHashMap(classProbs);
        boolean first = true;
        double maxProb = 0;
        String maxCategory = null;
        for (Map.Entry<String, Double> entry : classProbs.entrySet()) {
            // the first entry always wins so negative-only inputs still get a label.
            if (first || entry.getValue() > maxProb) {
                maxProb = entry.getValue();
                maxCategory = entry.getKey();
                first = false;
            }
        }
        setLabel(maxCategory);
    }

    /** Probability of category, or null if it is unknown. */
    public Double getProb(String category) {
        return classProbs.get(category);
    }

    /** Probability of category, or def if it is unknown. */
    public Double getProbOrElse(String category, Double def) {
        if (classProbs.containsKey(category)) {
            return classProbs.get(category);
        } else {
            return def;
        }
    }

    public Map<String, Double> getMap() {
        return classProbs;
    }
}

================================================
FILE:
src/main/java/com/etsy/conjecture/data/RealValueLabeledInstance.java ================================================ package com.etsy.conjecture.data; import java.util.Map; public class RealValueLabeledInstance extends AbstractInstance implements LabeledInstance { private final RealValuedLabel label; public RealValuedLabel getLabel() { return label; } public RealValueLabeledInstance() { this(0.0); } public RealValueLabeledInstance(RealValuedLabel label) { this(label, 1.0); } public RealValueLabeledInstance(RealValuedLabel label, double weight) { super(weight); this.label = label; } public RealValueLabeledInstance(double label) { this(new RealValuedLabel(label), 1.0); } public RealValueLabeledInstance(double label, double weight) { this(new RealValuedLabel(label), weight); } public RealValueLabeledInstance(double label, Map instance) { this(new RealValuedLabel(label), instance, 1.0); } public RealValueLabeledInstance(double label, Map instance, double weight) { this(new RealValuedLabel(label), instance, weight); } public RealValueLabeledInstance(double label, StringKeyedVector vec) { this(new RealValuedLabel(label), vec.getMap(), 1.0); } public RealValueLabeledInstance(double label, StringKeyedVector vec, double weight) { this(new RealValuedLabel(label), vec.getMap(), weight); } public RealValueLabeledInstance(RealValuedLabel label, Map instance) { this(label, instance, 1.0); } public RealValueLabeledInstance(RealValuedLabel label, Map instance, double weight) { super(instance, weight); this.label = label; } public RealValueLabeledInstance(RealValuedLabel label, StringKeyedVector vec) { this(label, vec, 1.0); } public RealValueLabeledInstance(RealValuedLabel label, StringKeyedVector vec, double weight) { super(vec.getMap(), weight); this.label = label; } } ================================================ FILE: src/main/java/com/etsy/conjecture/data/RealValuedLabel.java ================================================ package com.etsy.conjecture.data; public class 
RealValuedLabel extends Label { protected final Double value; private static final long serialVersionUID = -1L; public RealValuedLabel(double value) { this.value = new Double(value); } public Double getValue() { return this.value; } @Override public String toString() { return value + ""; } } ================================================ FILE: src/main/java/com/etsy/conjecture/data/Recommendation.java ================================================ package com.etsy.conjecture.data; import java.io.Serializable; public class Recommendation implements Serializable { private static final long serialVersionUID = 1L; public final double score; public final String id; public Recommendation(String id, double score) { this.id = id; this.score = score; } } ================================================ FILE: src/main/java/com/etsy/conjecture/data/StringKeyedVector.java ================================================ package com.etsy.conjecture.data; import gnu.trove.function.TDoubleFunction; import gnu.trove.iterator.TObjectDoubleIterator; import java.io.Serializable; import java.util.Iterator; import java.util.Map; import java.util.Set; import com.etsy.conjecture.Utilities; import com.google.gson.Gson; public class StringKeyedVector implements Serializable, Iterable> { private static final long serialVersionUID = -7070522686694887436L; // - represent the sparse vector by a mapping of coordinate name strings // (feature names) // to doubles. protected ByteArrayDoubleHashMap vector; // - whether to permit the addition of more features to this vector. protected boolean freezeKeySet = false; // - the load factor for the underlying hashmap. 
public static final float LOAD_FACTOR = 0.9f; public static final String FEATURE_ENCODING = "ASCII"; public StringKeyedVector() { this(10); } public StringKeyedVector(int initialCapacity) { vector = new ByteArrayDoubleHashMap(initialCapacity, LOAD_FACTOR, FEATURE_ENCODING, 0.0); } public StringKeyedVector(StringKeyedVector skv) { this(skv.size()); add(skv); } public StringKeyedVector(Map jmap) { vector = new ByteArrayDoubleHashMap(jmap.size(), LOAD_FACTOR, FEATURE_ENCODING, 0.0); vector.putAll(jmap); } /** * returns whether the key set is frozen (true means that further dimensions * cannot be added to this vector). */ public boolean getFreezeKeySet() { return freezeKeySet; } /** * sets whether the key set is frozen (true means that further dimensions * cannot be added to this vector). */ public void setFreezeKeySet(boolean freeze) { freezeKeySet = freeze; } /** * disregards prior value at a particular key, replacing with the specified * value. */ public double setCoordinate(String key, double value) { if (Utilities.floatingPointEquals(value, 0d)) { return deleteCoordinate(key); } else if (!freezeKeySet) { vector.putPrimitive(key, value); } return 0d; } /** * remove a coordinate from the vector (same as setting it to 0). */ public double deleteCoordinate(String key) { if (vector.containsKey(key) && !freezeKeySet) { return vector.removePrimitive(key); } else { return 0d; } } public Map getMap() { return vector; } /** * add to a specified coordinate (treating it as 0 if it was not present). 
* Existing coordinates are updated even when the key set is frozen; a
     * coordinate whose sum becomes (approximately) zero is removed.
     */
    public double addToCoordinate(String key, double value) {
        byte[] bkey = vector.stringToByteArray(key);
        return addToCoordinateInternal(bkey, value);
    }

    /**
     * Adds {@code value} to the coordinate stored under the already-encoded key
     * {@code bkey}.
     *
     * <p>An existing coordinate whose updated value is (approximately) zero is
     * removed. A missing coordinate is created only when the key set is not
     * frozen and {@code value} is non-zero.
     *
     * @return the result of the underlying removePrimitive/putPrimitive call when
     *         the key already existed, otherwise 0
     */
    protected double addToCoordinateInternal(byte[] bkey, double value) {
        if (vector.containsKey(bkey)) {
            double updated = vector.getPrimitive(bkey) + value;
            if (Utilities.floatingPointEquals(updated, 0.0d)) {
                return vector.removePrimitive(bkey);
            } else {
                return vector.putPrimitive(bkey, updated);
            }
        } else if (!freezeKeySet && !Utilities.floatingPointEquals(value, 0.0d)) {
            vector.putPrimitive(bkey, value);
        }
        return 0d;
    }

    /**
     * return the value of a coordinate (whatever {@code getPrimitive} yields for
     * a missing key; the constructors pass 0.0 as the map's default).
     */
    public double getCoordinate(String key) {
        return vector.getPrimitive(key);
    }

    /**
     * Adds {@code scale * vec} to this vector, coordinate by coordinate.
     * A {@link LazyVector} argument is delazified first.
     */
    public void addScaled(StringKeyedVector vec, double scale) {
        if (vec instanceof LazyVector) {
            ((LazyVector) vec).delazify();
        }
        // When vec is this vector, adding while iterating would mutate the map
        // being iterated (entries that cancel to zero are removed mid-loop), so
        // iterate an independent snapshot instead.
        StringKeyedVector source = (vec == this) ? new StringKeyedVector(getMap()) : vec;
        for (TObjectDoubleIterator it = source.vector.troveIterator(); it.hasNext();) {
            it.advance();
            addToCoordinateInternal(it.key(), scale * it.value());
        }
    }

    /**
     * Returns a new vector whose coordinate {@code k} is {@code this[k] * vec[k]}
     * for every coordinate {@code k} stored in {@code vec}. Products that are
     * (approximately) zero are not stored, keeping the result sparse.
     */
    public StringKeyedVector multiplyPointwise(StringKeyedVector vec) {
        StringKeyedVector res = new StringKeyedVector();
        if (vec instanceof LazyVector) {
            ((LazyVector) vec).delazify();
        }
        for (TObjectDoubleIterator it = vec.vector.troveIterator(); it.hasNext();) {
            it.advance();
            // addToCoordinateInternal skips zero values on a fresh key, like
            // projectOntoNonZeroCoordinates, instead of storing explicit zeros.
            res.addToCoordinateInternal(it.key(), vector.getPrimitive(it.key()) * it.value());
        }
        return res;
    }

    /**
     * Returns a new vector holding this vector's values restricted to the
     * coordinates stored in {@code vec}. Coordinates where this vector reads
     * as zero are not stored.
     */
    public StringKeyedVector projectOntoNonZeroCoordinates(StringKeyedVector vec) {
        StringKeyedVector res = new StringKeyedVector();
        if (vec instanceof LazyVector) {
            ((LazyVector) vec).delazify();
        }
        for (TObjectDoubleIterator it = vec.vector.troveIterator(); it.hasNext();) {
            it.advance();
            res.addToCoordinateInternal(it.key(), vector.getPrimitive(it.key()));
        }
        return res;
    }

    /**
     * the number of coordinates currently stored in the vector.
     */
    public int size() {
        return vector.size();
    }

    /**
     * whether this vector has a non-zero value for a coordinate.
*/ public boolean containsKey(String key) { return vector.containsKey(key); } /** * whether this vector has a non-zero value for a coordinate. */ public boolean contains(String key) { return containsKey(key); } /** * the set of non-zero coordinate names. */ public Set keySet() { return vector.keySet(); } /** * the set of values in the map. */ public Set values() { return vector.values(); } /** * add vec to this */ public void add(StringKeyedVector vec) { addScaled(vec, 1.0); } /** * subtract vec from this. */ public void sub(StringKeyedVector vec) { addScaled(vec, -1.0); } /** * multiply this vector by a scalar. */ public void mul(final double a) { transformValues(new TDoubleFunction() { public double execute(double b) { return a * b; } }); } /** * Apply an arbitrary scalar function to the values. */ public void transformValues(TDoubleFunction func) { vector.transformValues(func); } /** * Remove zeros that may have appeared as a result of a transform */ public void removeZeroCoordinates() { @SuppressWarnings("unused") int i = 0; for (TObjectDoubleIterator it = vector.troveIterator(); it .hasNext();) { it.advance(); if (Utilities.floatingPointEquals(it.value(), 0d)) { i++; it.remove(); } } } /** * compute the inner product between this and vec. */ public double dot(StringKeyedVector vec) { if (vec instanceof LazyVector) { return vec.dot(this); } ByteArrayDoubleHashMap vec_small = this.size() > vec.size() ? vec.vector : this.vector; ByteArrayDoubleHashMap vec_big = this.size() > vec.size() ? this.vector : vec.vector; double res = 0.0; for (TObjectDoubleIterator it = vec_small.troveIterator(); it .hasNext();) { it.advance(); if (vec_big.containsKey(it.key())) { res += it.value() * vec_big.getPrimitive(it.key()); } } return res; } /** * compute the LP norm for given p < infinity. */ public double LPNorm(double p) { double tot = 0d; for (double v : vector.values()) { tot += Math.pow(Math.abs(v), p); } return Math.pow(tot, 1d / p); } /** * Find the max value. 
*/ public double max() { double max = 0.0; for (double v : vector.values()) { if (v > max) { max = v; } } return max; } /** * immutable access the underlying hash map. */ public Iterator> iterator() { return vector.iterator(); } public String toString() { Gson gson = new Gson(); return gson.toJson(vector); } /** * performs a deep copy of a stringkeyedvector * */ public StringKeyedVector copy() { StringKeyedVector out = new StringKeyedVector(this.size()); Iterator> it = this.iterator(); while (it.hasNext()) { Map.Entry entry = it.next(); String key = entry.getKey(); Double value = entry.getValue(); out.setCoordinate(key, value); } return out; } } ================================================ FILE: src/main/java/com/etsy/conjecture/evaluation/BinaryModelEvaluation.java ================================================ package com.etsy.conjecture.evaluation; import java.io.Serializable; import java.util.HashMap; import java.util.Map; import java.util.SortedMap; import java.util.TreeMap; import com.etsy.conjecture.data.BinaryLabel; import com.etsy.conjecture.PrimitivePair; /** * a basic container for evaluations TODO: add getters for individual metrics */ public class BinaryModelEvaluation implements ModelEvaluation, Serializable { private static final long serialVersionUID = 1L; private final ReceiverOperatingCharacteristic ROC; private final ConfusionMatrix conf; public BinaryModelEvaluation() { ROC = new ReceiverOperatingCharacteristic(); conf = new ConfusionMatrix(2); } public void merge(ModelEvaluation other) { BinaryModelEvaluation tempOther = (BinaryModelEvaluation) other; ROC.add(tempOther.ROC); conf.add(tempOther.conf); } public void add(BinaryLabel real, BinaryLabel pred) { add(real.getValue(), pred.getValue()); } public void add(double label, double prediction) { ROC.add(label, prediction); conf.addHard((int)label, prediction); } public void add(PrimitivePair labelPrediction) { ROC.add(labelPrediction); conf.addHard((int)labelPrediction.first, 
labelPrediction.second); } public double computeAUC() { return ROC.binaryAUC(); } public double computeBrier() { return ROC.brierScore(); } public double computeAccy() { return conf.computeAccuracy(); } public double computeAccy(int dim) { return conf.computeAccuracy(dim); } public double computeFmeasure() { return conf.computeAverageFmeasure(); } public double computeFmeasure(int dim) { return conf.computeFmeasure(dim); } public double computePrecision() { return conf.computeAveragePrecision(); } public double computePrecision(int dim) { return conf.computePrecision(dim); } public double computeRecall() { return conf.computeAverageRecall(); } public double computeRecall(int dim) { return conf.computeRecall(dim); } public Map getStatistics() { SortedMap m = new TreeMap(); m.put("Brier", computeBrier()); m.put("Acc (avg)", computeAccy()); m.put("F1 (avg)", computeFmeasure()); m.put("Prc (avg)", computePrecision()); m.put("Rec (avg)", computeRecall()); m.put("0-class Acc", computeAccy(0)); m.put("0-class F1", computeFmeasure(0)); m.put("0-class Prc", computePrecision(0)); m.put("0-class Rec", computeRecall(0)); m.put("1-class Acc", computeAccy(1)); m.put("1-class F1", computeFmeasure(1)); m.put("1-class Prc", computePrecision(1)); m.put("1-class Rec", computeRecall(1)); m.put("1-class AUC", computeAUC()); return m; } public Map getObjects() { Map m = new HashMap(); m.put("conf", conf.toString()); return m; } } ================================================ FILE: src/main/java/com/etsy/conjecture/evaluation/ConfusionMatrix.java ================================================ package com.etsy.conjecture.evaluation; import java.io.Serializable; import java.util.Collection; import com.etsy.conjecture.PrimitivePair; import static com.google.common.base.Preconditions.checkArgument; /** * class representing a confusion matrix for representing misclassification * errors. 
* {@link Confusion Matrix} * * @author jattenberg */ public class ConfusionMatrix implements Serializable { private static final long serialVersionUID = 1L; /** * The data structure representing the confusion matrix. rows correspond to * labels, columns to predictions */ private double[][] confMatrix; /** The num_classes represented in the confusion matrix */ private final int numClasses; /** The number of label / prediction pairs observed */ double obs; /** * Instantiates a new confusion matrix. * * @param classes * the number of target classes in the problem being considered */ public ConfusionMatrix(int classes) { obs = 0; this.numClasses = classes; this.confMatrix = new double[numClasses][numClasses]; } public void add(ConfusionMatrix m) { obs += m.obs; for (int i = 0; i < numClasses; i++) { for (int j = 0; j < numClasses; j++) { confMatrix[i][j] += m.confMatrix[i][j]; } } } /** * Instantiates a new confusion matrix and adds some initial data * * @param classes * - the number of target classes in the problem being considered * @param labelsAndPredictions * the labels and predictions */ public ConfusionMatrix(int classes, Collection labelsAndPredictions) { this(classes); for (PrimitivePair p : labelsAndPredictions) addInfo(p.first, p.second); } /** * Instantiates a new confusion matrix and adds some initial data * * @param classes * - the number of target classes in the problem being considered * @param labelsAndPredictions * the labels and predictions */ public ConfusionMatrix(int classes, PrimitivePair[] labelsAndPredictions) { this(classes); for (PrimitivePair p : labelsAndPredictions) addInfo(p.first, p.second); } /** * Instantiates a new confusion matrix and adds some initial data * * @param classes * - the number of target classes in the problem being considered * @param labelsAndPredictions * the labels and predictions */ public ConfusionMatrix(int classes, double[] labels, double[] predictions) { this(classes); checkArgument( labels.length == 
predictions.length, "labels and predictions must be of the same length! (%s vs %s)", labels.length, predictions.length); for (int i = 0; i < labels.length; i++) { addInfo(labels[i], predictions[i]); } } /** * Adds a label / prediction pair to the confusion matrix * * @param label * the index of the actual class * @param guess * the index of the predicted class */ public void addInfo(int label, int guess) { obs++; this.confMatrix[label][guess]++; } /** * Adds a label / prediction pair to the confusion matrix with soft labels * * @param label * the index of the actual class * @param guess * the predicted distribution over classes. */ public void addInfo(int label, double[] guess) { addInfo(label, guess, 1); } /** * Adds a label / prediction pair to the confusion matrix with soft labels * * @param label * the index of the actual class * @param guess * the predicted distribution over classes. * @param freq * the number of times to consider the input label / prediction * pair */ public void addInfo(int label, double[] guess, double freq) { checkArgument( guess.length == numClasses, "input lenght (%d) must match num classes in confusion matrix (%d) ", guess.length, numClasses); obs += freq; for (int i = 0; i < numClasses; i++) { confMatrix[label][i] += freq * guess[i]; } } /** * Adds a label / prediction pair to the confusion matrix with soft labels * note, only applicable for binary classification (2 class) problems * * @param label * the actual probability of membership in the positive class * @param prediction * the predicted probability of membership in the positive class */ public void addInfo(double label, double prediction) { checkArgument( 2 == numClasses, "num classes in confusion matrix (%d) must be 2 for this method", numClasses); addInfo(new double[] { 1. - label, label }, new double[] { 1. 
- prediction, prediction }); } /** * Adds a label / prediction pair to the confusion matrix with soft labels * * @param softlabels * actual distribution of target class memberships * @param guess * the predicted distribution of class memberships */ public void addInfo(double[] softlabels, double[] guess) { obs++; for (int i = 0; i < numClasses; i++) { for (int j = 0; j < numClasses; j++) { confMatrix[i][j] += softlabels[i] * guess[j]; } } } /** * Adds a label / prediction pair to the confusion matrix with soft labels * * @param softlabels * actual distribution of target class memberships * @param guess * the predicted distribution of class memberships * @param freq * the number of times to consider this label / prediction pair */ public void addInfo(double[] softlabels, double[] guess, double freq) { obs += freq; for (int i = 0; i < numClasses; i++) { for (int j = 0; j < numClasses; j++) { confMatrix[i][j] += softlabels[i] * guess[j] * freq; } } } /** * Computes the actual distribution over labels * * @return the double[] encoding probabilities in each class. 
*/ public double[] classDistribution() { double[] dists = new double[this.numClasses]; for (int i = 0; i < numClasses; i++) { dists[i] = classDistribution(i); } return dists; } /** * Computes the actual probability of mambership in a particular class * denoted by the input index * * @param num * index of the class of interest * @return the probability of membership in the requested class */ public double classDistribution(int num) { double classSum = 0; double totSum = 0; for (int i = 0; i < numClasses; i++) { for (int j = 0; j < numClasses; j++) { if (i == num) classSum += confMatrix[i][j]; totSum += confMatrix[i][j]; } } return classSum / totSum; } /** * Adds a label / prediction pair to the confusion matrix with hard (most * likely class) labels * * @param softlabels * actual distribution of target class memberships * @param guess * the predicted distribution of class memberships * @param freq * the number of times to consider this label / prediction pair */ public void addHard(double[] softlabels, double[] guess, double weight) { addInfo(softToHard(softlabels), softToHard(guess), weight); } /** * Adds a label / prediction pair to the confusion matrix with hard (most * likely class) labels * * @param softlabels * actual distribution of target class memberships * @param guess * the predicted distribution of class memberships */ public void addHard(double[] softlabels, double[] guess) { addInfo(softToHard(softlabels), softToHard(guess)); } /** * Adds a label / prediction pair to the confusion matrix with hard (most * likely class) labels note, only applicable for binary classification (2 * class) problems * * @param label * the index of the actual class of membership * @param prediction * the predicted probability of membership in the positive class */ public void addHard(int label, double[] guess) { addInfo(label, softToHard(guess)); } /** * Adds a label / prediction pair to the confusion matrix with hard (most * likely class) labels note, only applicable for 
binary classification (2 * class) problems * * @param label * the index of the actual class of membership * @param prediction * the predicted probability of membership in the positive class */ public void addHard(int label, double prediction) { addInfo(label, softToHard(new double[] { 1.0 - prediction, prediction })); } /** * Adds a label / prediction pair to the confusion matrix with hard (most * likely class) labels note, only applicable for binary classification (2 * class) problems * * @param label * the index of the actual class of membership * @param prediction * the predicted probability of membership in the positive class * @param freq * the number of times this label / prediction pair should be * considered. */ public void addHard(int label, double[] guess, double freq) { addInfo(label, softToHard(guess)); } /** * converts a soft prediction of probability estimates into a categorical * indicator for the most likely class * * @param scores * probabilities of label class membership * @return the categorical values, 0's for all target classes with a 1 for * the most likely class */ private static double[] softToHard(double[] scores) { int maxindex = 0; double max = 0; double[] out = new double[scores.length]; for (int i = 0; i < scores.length; i++) { if (scores[i] > max) { maxindex = i; max = scores[i]; } } out[maxindex] = 1; return out; } /* * (non-Javadoc) * * @see java.lang.Object#toString() */ @Override public String toString() { StringBuilder buff = new StringBuilder(); buff.append("predicted:\t"); for (int i = 0; i < numClasses - 1; i++) { buff.append(i + "\t"); } buff.append((numClasses - 1) + "\n"); for (int i = 0; i < numClasses; i++) { buff.append("actually " + i + ":\t"); for (int j = 0; j < numClasses; j++) { buff.append(String.format("%.4f\t", confMatrix[i][j])); } buff.append("\n"); } return buff.toString(); } /** * To string row normalized (divided by the sum of each row) * * @return the string representation of the confusion matrix that has 
been row normalized; a row whose sum is zero prints NaN cells.
     */
    public String toStringRowNormalized() {
        StringBuilder buff = new StringBuilder();
        buff.append("predicted:\t");
        for (int i = 0; i < numClasses - 1; i++) {
            buff.append(i + "\t");
        }
        double[] rowSums = this.rowSums();
        buff.append((numClasses - 1) + "\n");
        for (int i = 0; i < numClasses; i++) {
            buff.append("actually " + i + ":\t");
            for (int j = 0; j < numClasses; j++) {
                String s = String.format("%.4f\t", confMatrix[i][j] / rowSums[i]);
                buff.append(s);
            }
            buff.append("\n");
        }
        return buff.toString();
    }

    /**
     * To string column normalized (each cell divided by the sum of its column)
     *
     * @return the string representation of the confusion matrix that has been
     *         column normalized; a column whose sum is zero prints NaN cells
     */
    public String toStringColNormalized() {
        StringBuilder buff = new StringBuilder();
        buff.append("predicted:\t");
        for (int i = 0; i < numClasses - 1; i++) {
            buff.append(i + "\t");
        }
        double[] colSums = this.colSums();
        buff.append((numClasses - 1) + "\n");
        for (int i = 0; i < numClasses; i++) {
            buff.append("actually " + i + ":\t");
            for (int j = 0; j < numClasses; j++) {
                // Cell (i, j) lies in column j, so normalize by column j's sum.
                String s = String.format("%.4f\t", confMatrix[i][j] / colSums[j]);
                buff.append(s);
            }
            buff.append("\n");
        }
        return buff.toString();
    }

    /**
     * Compute the sum of each row
     *
     * @return an array containing the sum of each row (total weight of each
     *         actual class).
     */
    public double[] rowSums() {
        double[] sums = new double[numClasses];
        for (int i = 0; i < numClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                sums[i] += confMatrix[i][j];
            }
        }
        return sums;
    }

    /**
     * Compute the accuracy for a given class; the % of examples that have been
     * correctly classified.
* * @param classid * the index of the class where accuracy has been requested * @return the % of correctly classified examples for the requested class */ public double computeAccuracy(int classid) { double tn = 0.; for (int i = 0; i < numClasses; i++) for (int j = 0; j < numClasses; j++) if (j != classid && i != classid) tn += confMatrix[i][j]; double tp = confMatrix[classid][classid]; return (tn + tp) / obs; } public double computeAverageFmeasure() { double[] rowSums = rowSums(); double total = total(rowSums); double fmeasure = 0.; for (int i = 0; i < numClasses; i++) { fmeasure += rowSums[i] * computeFmeasure(i); } return fmeasure / total; } public double computeAveragePrecision() { double[] rowSums = rowSums(); double total = total(rowSums); double precision = 0.; for (int i = 0; i < numClasses; i++) { precision += rowSums[i] * computePrecision(i); } return precision / total; } public double computeAverageRecall() { double[] rowSums = rowSums(); double total = total(rowSums); double recall = 0.; for (int i = 0; i < numClasses; i++) { recall += rowSums[i] * computeRecall(i); } return recall / total; } /** * Compute the sums of each column * * @return an array containing the sum of each column. */ public double[] colSums() { double[] sums = new double[numClasses]; for (int i = 0; i < numClasses; i++) { for (int j = 0; j < numClasses; j++) { sums[j] += confMatrix[i][j]; } } return sums; } /** * Return the confusion matrix as a 2d array * * @return the confusion matrix data structure */ public double[][] getMatrix() { double[][] out = new double[numClasses][numClasses]; for (int i = 0; i < numClasses; i++) for (int j = 0; j < numClasses; j++) out[i][j] = confMatrix[i][j]; return out; } /** * Gets the number of classes in the confusion matrix * * @return the number of classes considered */ public int getDim() { return this.numClasses; } /** * Computes the accuracy over all observations for all classes (% of * correctly labeled examples). 
*
     * @return accuracy over all classes, or 0 if the matrix is empty.
     */
    public double computeAccuracy() {
        double accy = 0;
        double tot = 0;
        double right = 0;
        for (int i = 0; i < this.numClasses; i++) {
            tot += total(this.confMatrix[i]);
            right += this.confMatrix[i][i];
        }
        if (tot > 0) {
            accy = right / tot;
        }
        return accy;
    }

    /**
     * Compute the precision for each class: of everything predicted as the
     * class (its column), the fraction that actually belongs to it.
     *
     * @return an array containing the precision values for each class; a
     *         class that was never predicted gets 0.
     */
    public double[] computePrecision() {
        double[] precision = new double[this.numClasses];
        for (int k = 0; k < this.numClasses; k++) {
            // Columns are predictions, so precision normalizes by the column sum
            // (consistent with computePrecision(int)).
            double predicted = 0;
            for (int i = 0; i < this.numClasses; i++) {
                predicted += confMatrix[i][k];
            }
            if (predicted != 0) {
                precision[k] = confMatrix[k][k] / predicted;
            }
        }
        return precision;
    }

    /**
     * Compute the recall for each class: of the class's actual members (its
     * row), the fraction that were labeled as class members.
     *
     * @return an array containing the recall values for each class.
*/ public double[] computeRecall() { double[] recall = new double[this.numClasses]; double yes[] = new double[this.numClasses]; double no[] = new double[this.numClasses]; for (int i = 0; i < this.numClasses; i++) { for (int j = 0; j < this.numClasses; j++) { if (i == j) yes[j] += confMatrix[i][j]; else no[j] += confMatrix[i][j]; } } for (int i = 0; i < numClasses; i++) { if (yes[i] + no[i] != 0) recall[i] = yes[i] / (yes[i] + no[i]); } return recall; } /** * Computes the F-measure for each class; the harmonic mean of precision and * recall * {@link F-Measure} * for more info * * @return the array containing the F-measure for each class */ public double[] computeFmeasure() { double[] fmeasure = new double[numClasses]; double[] precision = this.computePrecision(); double[] recall = this.computeRecall(); for (int i = 0; i < this.numClasses; i++) { if (recall[i] + precision[i] != 0) fmeasure[i] = 2.0 * (precision[i] * recall[i]) / (precision[i] + recall[i]); } return fmeasure; } /** * Builds a string table containing the common IR measures, precision, * recall, and F measure for each class * * @return the string with performance stats */ public String getIR() { StringBuffer buff = new StringBuffer(); buff.append("class\t" + "precision\t" + "recall\t" + "F measure\n"); double[] precision = this.computePrecision(); double[] recall = this.computeRecall(); double[] fmeasure = this.computeFmeasure(); for (int i = 0; i < numClasses; i++) { buff.append(i + "\t" + precision[i] + "\t" + recall[i] + "\t" + fmeasure[i] + "\n"); } return buff.toString(); } /** * Computes precision for a given class; the % of members of belonging to * each class that were labeled as class members * * @param dim * class of interest * @return the precision for the requested class */ public double computePrecision(int dim) { double tot = 0; for (int i = 0; i < numClasses; i++) tot += confMatrix[i][dim]; return confMatrix[dim][dim] / tot; } /** * Compute the recall for a given class; the % of members 
of belonging to * each class that were labeled as class members * * @param dim * the class of interest * @return the recall for the requested class */ public double computeRecall(int dim) { double tot = 0; for (int i = 0; i < numClasses; i++) tot += confMatrix[dim][i]; return confMatrix[dim][dim] / tot; } /** * Computes the F-measure for a given class; the harmonic mean of precision * and recall * {@link F-Measure} * for more info * * * @param dim * the class of interest * @return the F-Measure of the requested class */ public double computeFmeasure(int dim) { double pre = computePrecision(dim); double rec = computeRecall(dim); return 2 * (pre * rec) / (pre + rec); } /** * Total. * * @param arr * the arr * @return the double */ private double total(double[] arr) { double total = 0; for (int i = 0; i < arr.length; i++) total += arr[i]; return total; } /** * Builds a confusion matrix with the input observations and computes the * accuracy over all observations for all classes (% of correctly labeled * examples). * * * @param input * the input label / prediction pairs * @return the accuracy of the input values */ public static double computeAccuracy(Collection input) { ConfusionMatrix conf = new ConfusionMatrix(2); for (PrimitivePair p : input) conf.addInfo(new double[] { 1. - p.first, p.first }, new double[] { 1. 
- p.second, p.second }); return conf.computeAccuracy(); } } ================================================ FILE: src/main/java/com/etsy/conjecture/evaluation/EvaluationAggregator.java ================================================ package com.etsy.conjecture.evaluation; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.TreeMap; import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics; import com.etsy.conjecture.data.Label; public class EvaluationAggregator implements Serializable { private static final long serialVersionUID = 5825037849957449364L; protected Map stats = new TreeMap(); protected Map> obj = new HashMap>(); public void add(ModelEvaluation eval) { Map fold = eval.getStatistics(); if (!stats.isEmpty()) { if (!fold.keySet().equals(stats.keySet())) { throw new java.lang.RuntimeException( "Tried to add incompatible folds, with fields:" + fold.keySet().toString() + " and " + stats.keySet().toString()); } for (Map.Entry e : fold.entrySet()) { stats.get(e.getKey()).addValue(e.getValue()); } for (Map.Entry e : eval.getObjects().entrySet()) { obj.get(e.getKey()).add(e.getValue()); } } else { for (Map.Entry e : fold.entrySet()) { DescriptiveStatistics ds = new DescriptiveStatistics(); ds.addValue(e.getValue()); stats.put(e.getKey(), ds); } for (Map.Entry e : eval.getObjects().entrySet()) { obj.put(e.getKey(), new ArrayList(5)); obj.get(e.getKey()).add(e.getValue()); } } } public double getValue(String key) { return stats.get(key).getMean(); } @Override public String toString() { StringBuilder buff = new StringBuilder("Stat:\tMean\tStdDev\tMedian\n"); for (Map.Entry e : stats.entrySet()) { buff.append(e.getKey() + ":\t" + format(e.getValue()) + "\n"); } for (Map.Entry> e : obj.entrySet()) { buff.append(e.getKey()).append(":\n"); for (Object o : e.getValue()) { buff.append("----\n").append(o.toString()).append("\n"); } } return buff.toString(); } private 
String format(DescriptiveStatistics stats) { return String.format("%.4f\t%.4f\t%.4f", stats.getMean(), stats.getStandardDeviation(), stats.getPercentile(50)); } } ================================================ FILE: src/main/java/com/etsy/conjecture/evaluation/ModelEvaluation.java ================================================ package com.etsy.conjecture.evaluation; import com.etsy.conjecture.data.Label; import java.util.Map; public interface ModelEvaluation { public void add(L real, L predicted); public Map getStatistics(); public Map getObjects(); public void merge(ModelEvaluation other); } ================================================ FILE: src/main/java/com/etsy/conjecture/evaluation/MulticlassConfusionMatrix.java ================================================ package com.etsy.conjecture.evaluation; import java.io.Serializable; import java.util.Map; import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; /** * class representing a confusion matrix for representing misclassification * errors. * {@link Confusion Matrix} * * @author jattenberg */ public class MulticlassConfusionMatrix implements Serializable { private static final long serialVersionUID = 1L; /** * The data structure representing the confusion matrix. rows correspond to * labels, columns to predictions */ private final SortedMap> confusionMatrix; /** The num_classes represented in the confusion matrix */ private final int numClasses; /** The number of label / prediction pairs observed */ double obs; /** * Instantiates a new confusion matrix. 
* * @param classes * the number of target classes in the problem being considered */ public MulticlassConfusionMatrix(String[] categories) { obs = 0; this.numClasses = categories.length; confusionMatrix = initializeMatrix(categories); } public void add(MulticlassConfusionMatrix m) { obs += m.obs; for(Map.Entry> entry : m.confusionMatrix.entrySet()) { String label = entry.getKey(); SortedMap value = entry.getValue(); for(Map.Entry inner_entry : value.entrySet()) { String inner_label = inner_entry.getKey(); Double update = inner_entry.getValue(); confusionMatrix.get(label).put(inner_label, update + getValue(label, inner_label)); } } } private static SortedMap> initializeMatrix( Set categories) { String[] catArray = new String[categories.size()]; int ct = 0; for (String category : categories) { catArray[ct++] = category; } return initializeMatrix(catArray); } private static SortedMap> initializeMatrix( String[] categories) { SortedMap> conf = new TreeMap>(); for (String categoryOuter : categories) { conf.put(categoryOuter, new TreeMap()); for (String categoryInner : categories) { conf.get(categoryOuter).put(categoryInner, 0d); } } return conf; } private Double getValue(String label, String guess) { return confusionMatrix.get(label).get(guess); } private void updateConfusionMatrix(String label, String guess, double value) { confusionMatrix.get(label).put(guess, value + getValue(label, guess)); } private Map initializeProbabilityMatrix() { Map probs = new TreeMap(); for (String category : confusionMatrix.keySet()) { probs.put(category, 0d); } return probs; } /** * Adds a label / prediction pair to the confusion matrix * * @param label * the index of the actual class * @param guess * the index of the predicted class */ public void addInfo(String label, String guess) { obs++; updateConfusionMatrix(label, guess, 1d); } public void addInfo(String label, String guess, double freq) { obs += freq; updateConfusionMatrix(label, guess, freq); } /** * Adds a label / prediction 
pair to the confusion matrix with soft labels * * @param label * the index of the actual class * @param guess * the predicted distribution over classes. */ public void addInfo(String label, Map guesses) { addInfo(label, guesses, 1d); } /** * Adds a label / prediction pair to the confusion matrix with soft labels * * @param label * the index of the actual class * @param guess * the predicted distribution over classes. * @param freq * the number of times to consider the input label / prediction * pair */ public void addInfo(String label, Map predictions, double freq) { // TODO: ensure that sets match obs += freq; for (String category : predictions.keySet()) { updateConfusionMatrix(label, category, predictions.get(category)); } } /** * Adds a label / prediction pair to the confusion matrix with soft labels * * @param softlabels * actual distribution of target class memberships * @param guess * the predicted distribution of class memberships */ public void addInfo(Map labels, Map predictions) { addInfo(labels, predictions, 1d); } /** * Adds a label / prediction pair to the confusion matrix with soft labels * * @param softlabels * actual distribution of target class memberships * @param guess * the predicted distribution of class memberships * @param freq * the number of times to consider this label / prediction pair */ public void addInfo(Map labels, Map predictions, double freq) { obs += freq; for (String categoryLabel : labels.keySet()) { for (String categoryGuess : predictions.keySet()) { updateConfusionMatrix(categoryLabel, categoryGuess, freq); } } } /** * Computes the actual distribution over labels * * @return the Map classDistribution() { Map dists = initializeProbabilityMatrix(); for (String category : dists.keySet()) { dists.put(category, classDistribution(category)); } return dists; } /** * Computes the actual probability of mambership in a particular class * denoted by the input index TODO: implement this more efficiently * * @param num * index of the class 
of interest * @return the probability of membership in the requested class */ public double classDistribution(String category) { double classSum = 0; double totSum = 0; for (String categoryLabel : confusionMatrix.keySet()) { for (String categoryPrediction : confusionMatrix.keySet()) { if (categoryPrediction.equals(categoryLabel)) { classSum += getValue(categoryLabel, categoryPrediction); } totSum += getValue(categoryLabel, categoryPrediction); } } return totSum > 0d ? classSum : 0d; } /** * Adds a label / prediction pair to the confusion matrix with hard (most * likely class) labels * * @param softlabels * actual distribution of target class memberships * @param guess * the predicted distribution of class memberships * @param freq * the number of times to consider this label / prediction pair */ public void addHard(Map softlabels, Map predictions, double weight) { addInfo(softToHard(softlabels), softToHard(predictions), weight); } /** * Adds a label / prediction pair to the confusion matrix with hard (most * likely class) labels * * @param softlabels * actual distribution of target class memberships * @param guess * the predicted distribution of class memberships */ public void addHard(Map softlabels, Map predictions) { addInfo(softToHard(softlabels), softToHard(predictions), 1d); } /** * Adds a label / prediction pair to the confusion matrix with hard (most * likely class) labels note, only applicable for binary classification (2 * class) problems * * @param label * the index of the actual class of membership * @param prediction * the predicted probability of membership in the positive class */ public void addHard(String label, Map guess) { addInfo(label, softToHard(guess), 1d); } /** * Adds a label / prediction pair to the confusion matrix with hard (most * likely class) labels note, only applicable for binary classification (2 * class) problems * * @param label * the index of the actual class of membership * @param prediction * the predicted probability of 
membership in the positive class * @param freq * the number of times this label / prediction pair should be * considered. */ public void addHard(String label, Map guess, double freq) { addInfo(label, softToHard(guess), 1d); } /** * converts a soft prediction of probability estimates into a categorical * indicator for the most likely class * * @param scores * probabilities of label class membership * @return the categorical values, 0's for all target classes with a 1 for * the most likely class */ private static String softToHard(Map scores) { String maxindex = null; double max = Double.NEGATIVE_INFINITY; for (String category : scores.keySet()) { if (scores.get(category) > max) { max = scores.get(category); maxindex = category; } } return maxindex; } public String printDebug() { return ""; } /* * (non-Javadoc) * * @see java.lang.Object#toString() */ public String toString() { StringBuilder buff = new StringBuilder(); buff.append("predicted:\t"); for (String category : confusionMatrix.keySet()) { buff.append(category + "\t"); } buff.append("\n"); for (String categoryLabel : confusionMatrix.keySet()) { buff.append("actually " + categoryLabel + ":\t"); for (String categoryPrediction : confusionMatrix.keySet()) { buff.append(String.format("%.4f\t", getValue(categoryLabel, categoryPrediction))); } buff.append("\n"); } return buff.toString(); } /** * To string row normalized (divided by the sum of each row) * * @return the string representation of the confusion matrix that has been * row normalized */ public String toStringRowNormalized() { StringBuilder buff = new StringBuilder(); buff.append("predicted:\t"); for (String category : confusionMatrix.keySet()) { buff.append(category + "\t"); } Map rowSums = rowSums(); buff.append("\n"); for (String categoryLabel : confusionMatrix.keySet()) { buff.append("actually " + categoryLabel + ":\t"); for (String categoryPrediction : confusionMatrix.keySet()) { String s = String.format( "%.4f\t", getValue(categoryLabel, 
categoryPrediction) / rowSums.get(categoryLabel)); buff.append(s); } buff.append("\n"); } return buff.toString(); } /** * To string column normalized (divided by the sum of each column) * * @return the string representation of the confusion matrix that has been * column normalized */ public String toStringColNormalized() { StringBuilder buff = new StringBuilder(); buff.append("predicted:\t"); for (String category : confusionMatrix.keySet()) { buff.append(category + "\t"); } Map colSums = colSums(); buff.append("\n"); for (String categoryLabel : confusionMatrix.keySet()) { buff.append("actually " + categoryLabel + ":\t"); for (String categoryPrediction : confusionMatrix.keySet()) { String s = String.format( "%.4f\t", getValue(categoryLabel, categoryPrediction) / colSums.get(categoryLabel)); buff.append(s); } buff.append("\n"); } return buff.toString(); } /** * Compute the sum of each row * * @return an array containing the sum of each row. */ public Map rowSums() { Map sums = initializeProbabilityMatrix(); for (String cateogryLabel : confusionMatrix.keySet()) { for (String cateogryPrediction : confusionMatrix.keySet()) { sums.put( cateogryLabel, sums.get(cateogryLabel) + getValue(cateogryLabel, cateogryPrediction)); } } return sums; } /** * Compute the accuracy for a given class; the % of examples that have been * correctly classifieed. 
* * @param classid * the index of the class where accuracy has been requested * @return the % of correctly classified examples for the requested class */ public double computeAccuracy(String classId) { double tn = 0.; for (String categoryLabel : confusionMatrix.keySet()) { for (String categoryPrediction : confusionMatrix.keySet()) { if (!categoryLabel.equals(classId) && !categoryPrediction.equals(classId)) tn += getValue(categoryLabel, categoryPrediction); } } double tp = getValue(classId, classId); return (tn + tp) / obs; } public double computeAverageFmeasure() { Map rowSums = rowSums(); double total = total(rowSums); double fmeasure = 0.; for (String category : confusionMatrix.keySet()) { fmeasure += rowSums.get(category) * computeFmeasure(category); } return fmeasure / total; } public double computeAveragePrecision() { Map rowSums = rowSums(); double total = total(rowSums); double precision = 0.; for (String category : confusionMatrix.keySet()) { double pre = computePrecision(category); if (!Double.isNaN(pre)) { // when nothing is predicted as category, // pre is NaN precision += rowSums.get(category) * pre; } } return precision / total; } public double computeAverageRecall() { Map rowSums = rowSums(); double total = total(rowSums); double recall = 0.; for (String category : confusionMatrix.keySet()) { double re = computeRecall(category); if (!Double.isNaN(re)) { // re is NaN when there are 0 examples with // label of category recall += rowSums.get(category) * re; } } return recall / total; } /** * Compute the sums of each column * * @return an array containing the sum of each column. 
*/ public Map colSums() { Map sums = initializeProbabilityMatrix(); for (String categoryLabel : confusionMatrix.keySet()) { for (String categoryPrediction : confusionMatrix.keySet()) { sums.put(categoryPrediction, sums.get(categoryPrediction) + getValue(categoryLabel, categoryPrediction)); } } return sums; } /** * Return the confusion matrix as a 2d array * * @return the confusion matrix data structure */ public SortedMap> getMatrix() { SortedMap> out = initializeMatrix(confusionMatrix .keySet()); for (String categoryLabel : confusionMatrix.keySet()) { for (String categoryPrediction : confusionMatrix.keySet()) { out.get(categoryLabel).put(categoryPrediction, getValue(categoryLabel, categoryPrediction)); } } return out; } /** * Gets the number of classes in the confusion matrix * * @return the number of classes considered */ public int getDim() { return this.numClasses; } /** * Computes the accuracy over all observations for all classes (% of * correctly labeled examples). * * @return accuracy over all classes. */ public double computeAccuracy() { double accy = 0d; double tot = 0d; double right = 0d; Map rowSums = rowSums(); for (String category : confusionMatrix.keySet()) { tot += rowSums.get(category); right += getValue(category, category); } if (tot > 0) { accy = right / tot; } return accy; } /** * Compute the precision for each class; the % of members of labeled as * belonging to each class who were actually members of that class * * @return an array containing the precision values for each class. */ public Map computePrecision() { Map precision = initializeProbabilityMatrix(); for (String categoryLabel : confusionMatrix.keySet()) { precision.put(categoryLabel, computePrecision(categoryLabel)); } return precision; } /** * Compute the recall for each class; the % of members of belonging to each * class that were labeled as class members * * @return an array containing the recall values for each class. 
*/ public Map computeRecall() { Map recall = initializeProbabilityMatrix(); for (String categoryLabel : confusionMatrix.keySet()) { recall.put(categoryLabel, computeRecall(categoryLabel)); } return recall; } /** * Computes the F-measure for each class; the harmonic mean of precision and * recall * {@link F-Measure} * for more info * * @return the array containing the F-measure for each class */ public Map computeFmeasure() { Map fmeasure = initializeProbabilityMatrix(); Map precision = this.computePrecision(); Map recall = this.computeRecall(); for (String category : confusionMatrix.keySet()) { if (recall.get(category) + precision.get(category) != 0) fmeasure.put( category, 2.0 * (precision.get(category) * recall .get(category)) / (precision.get(category) + recall .get(category))); } return fmeasure; } /** * Builds A String Table Containing The common IR measures, precision, * recall, and F measure for each class * * @return the string with performance stats */ public String getIR() { StringBuffer buff = new StringBuffer(); buff.append("class\t" + "precision\t" + "recall\t" + "F measure\n"); Map precision = this.computePrecision(); Map recall = this.computeRecall(); Map fmeasure = this.computeFmeasure(); for (String category : confusionMatrix.keySet()) { buff.append(category + "\t" + precision.get(category) + "\t" + recall.get(category) + "\t" + fmeasure.get(category) + "\n"); } return buff.toString(); } /** * Computes precision for a given class; the % of members of belonging to * each class that were labeled as class members * * @param dim * class of interest * @return the precision for the requested class */ public double computePrecision(String category) { double tot = 0; for (String label : confusionMatrix.keySet()) { tot += getValue(label, category); } return getValue(category, category) / tot; } /** * Compute the recall for a given class; the % of members of belonging to * each class that were labeled as class members * * @param dim * the class of interest * 
@return the recall for the requested class */ public double computeRecall(String category) { double tot = 0; for (String prediction : confusionMatrix.keySet()) tot += getValue(category, prediction); return getValue(category, category) / tot; } /** * Computes the F-measure for a given class; the harmonic mean of precision * and recall * {@link F-Measure} * for more info * * * @param dim * the class of interest * @return the F-Measure of the requested class */ public double computeFmeasure(String category) { double pre = computePrecision(category); double rec = computeRecall(category); return 2 * (pre * rec) / (pre + rec); } /** * Total. * * @param arr * the arr * @return the double */ private double total(Map probs) { double total = 0; for (String category : probs.keySet()) total += probs.get(category); return total; } } ================================================ FILE: src/main/java/com/etsy/conjecture/evaluation/MulticlassModelEvaluation.java ================================================ package com.etsy.conjecture.evaluation; import java.io.Serializable; import java.util.HashMap; import java.util.Map; import java.util.SortedMap; import java.util.TreeMap; import com.etsy.conjecture.GenericPair; import com.etsy.conjecture.data.MulticlassLabel; import com.etsy.conjecture.data.MulticlassPrediction; /** * a basic container for evaluations TODO: add getters for individual metrics */ public class MulticlassModelEvaluation implements Serializable, ModelEvaluation { /** * */ private static final long serialVersionUID = 4916724871985109129L; private final MulticlassReceiverOperatingCharacteristic ROC; private final MulticlassConfusionMatrix conf; private final String[] categories; public MulticlassModelEvaluation(String[] categories) { this.categories = categories; ROC = new MulticlassReceiverOperatingCharacteristic(categories); conf = new MulticlassConfusionMatrix(categories); } public void add(String label, MulticlassPrediction prediction) { ROC.add(label, 
prediction); conf.addInfo(label, prediction.getLabel()); } public void merge(ModelEvaluation other) { MulticlassModelEvaluation tempOther = (MulticlassModelEvaluation) other; ROC.add(tempOther.ROC); conf.add(tempOther.conf); } public void add(GenericPair labelPrediction) { add(labelPrediction.first, labelPrediction.second); } public void add(MulticlassLabel real, MulticlassLabel pred) { if (!(pred instanceof MulticlassPrediction)) { throw new java.lang.RuntimeException( "MulticlassModelEvaluation needs a MulticlassPrediction"); } add(real.getLabel(), (MulticlassPrediction)pred); } public double computeAUC() { return ROC.multiclassAUC(); } public double computeAUC(String dim) { return ROC.singleClassAUC(dim); } public double computeBrier() { return ROC.multiclassBrierScore(); } public double computeAccy() { return conf.computeAccuracy(); } public double computeAccy(String dim) { return conf.computeAccuracy(dim); } public double computeFmeasure() { return conf.computeAverageFmeasure(); } public double computeFmeasure(String dim) { return conf.computeFmeasure(dim); } public double computePrecision() { return conf.computeAveragePrecision(); } public double computePrecision(String dim) { return conf.computePrecision(dim); } public double computeRecall() { return conf.computeAverageRecall(); } public double computeRecall(String dim) { return conf.computeRecall(dim); } public double computePercent(String dim) { return ROC.computePercent(dim); } public String printDebug() { return conf.printDebug(); } public Map getStatistics() { SortedMap m = new TreeMap(); m.put("AUC (avg)", computeAUC()); m.put("Brier (avg)", computeBrier()); m.put("Acc (avg)", computeAccy()); m.put("F1 (avg)", computeFmeasure()); m.put("Prc (avg)", computePrecision()); m.put("Rec (avg)", computeRecall()); for (String category : categories) { m.put(category + ": Pct", computePercent(category)); m.put(category + ": AUC", computeAUC(category)); m.put(category + ": Acc", computeAccy(category)); 
m.put(category + ": F1", computeFmeasure(category)); m.put(category + ": Prc", computePrecision(category)); m.put(category + ": Rec", computeRecall(category)); } return m; } public String toString() { StringBuilder buff = new StringBuilder(); buff.append("AUC: " + ROC.multiclassAUC() + "\n"); buff.append("Brier: " + ROC.multiclassBrierScore() + "\n"); buff.append("IR metrics:\n" + conf.getIR() + "\n"); buff.append("Confusion Matrix:\n" + conf.toString() + "\n"); return buff.toString(); } public HashMap getObjects() { HashMap m = new HashMap(); m.put("conf", conf.toString()); return m; } } ================================================ FILE: src/main/java/com/etsy/conjecture/evaluation/MulticlassReceiverOperatingCharacteristic.java ================================================ package com.etsy.conjecture.evaluation; import java.io.Serializable; import java.util.Collection; import java.util.Map; import java.util.HashMap; import com.etsy.conjecture.GenericPair; import com.etsy.conjecture.data.MulticlassPrediction; import static com.google.common.base.Preconditions.checkArgument; public class MulticlassReceiverOperatingCharacteristic implements Serializable { private static final long serialVersionUID = 1L; /** Num examples in each class. */ private Map classCounts; /** Num total examples */ private int numExamples; /** Binary ROCs for each class */ private Map classROC; /** * Instantiates a new receiver operating characteristic. 
*/ public MulticlassReceiverOperatingCharacteristic(String[] categories) { classROC = new HashMap(); classCounts = new HashMap(); for (String category : categories) { classROC.put(category, new ReceiverOperatingCharacteristic()); classCounts.put(category, 0); } } public void add(MulticlassReceiverOperatingCharacteristic other) { numExamples += other.numExamples; for(Map.Entry entry : other.classCounts.entrySet()) { String category = entry.getKey(); Integer count = entry.getValue(); classCounts.put(category, classCounts.get(category)+count); } for(Map.Entry entry : other.classROC.entrySet()) { String category = entry.getKey(); ReceiverOperatingCharacteristic update = entry.getValue(); ReceiverOperatingCharacteristic roc = classROC.get(category); roc.add(update); classROC.put(category, roc); } } public void add(GenericPair labelPrediction) { add(labelPrediction.first, labelPrediction.second); } public void add(String label, MulticlassPrediction prediction) { checkArgument(classCounts.containsKey(label), "label is of unknown category: %s", label); checkArgument(classROC.containsKey(label), "label is of unknown category: %s", label); // accum class counts int count = classCounts.get(label); classCounts.put(label, count + 1); // accum total counts; numExamples++; // add to individual binary ROC classes for (String category : classCounts.keySet()) { double binaryPrediction = prediction.getMap().get(category); double classLabel = category.equals(label) ? 
1d : 0d; classROC.get(category).add(classLabel, binaryPrediction); } } public double multiclassAUC() { double weightedAverageAUC = 0d; for (String label : classCounts.keySet()) { double classInfluence = (double)classCounts.get(label) / numExamples; ReceiverOperatingCharacteristic roc = classROC.get(label); double classAUC = roc.binaryAUC(); weightedAverageAUC += classInfluence * classAUC; } return weightedAverageAUC; } public double singleClassAUC(String category) { return classROC.get(category).binaryAUC(); } public double multiclassBrierScore() { double brierScore = 0d; int numClasses = classCounts.keySet().size(); for (String label : classCounts.keySet()) { brierScore += (classROC.get(label)).brierScore(); } return brierScore / numClasses; } public double computePercent(String category) { return classCounts.get(category) / (double) numExamples; } public static double computeAUC( Collection> labelsAndPredictions, String[] categories) { MulticlassReceiverOperatingCharacteristic roc = new MulticlassReceiverOperatingCharacteristic( categories); for (GenericPair p : labelsAndPredictions) roc.add((String)p.first, (MulticlassPrediction)p.second); return roc.multiclassAUC(); } public static double computeBrierScore( Collection> labelsAndPredictions, String[] categories) { MulticlassReceiverOperatingCharacteristic roc = new MulticlassReceiverOperatingCharacteristic( categories); for (GenericPair p : labelsAndPredictions) roc.add(p.first, p.second); return roc.multiclassBrierScore(); } } ================================================ FILE: src/main/java/com/etsy/conjecture/evaluation/ReceiverOperatingCharacteristic.java ================================================ package com.etsy.conjecture.evaluation; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.NavigableMap; import java.util.TreeMap; import com.etsy.conjecture.PrimitivePair; public 
class ReceiverOperatingCharacteristic implements Serializable { private static class NumComparator implements Comparator, Serializable { private static final long serialVersionUID = 6569477679353298040L; @Override public int compare(Double o1, Double o2) { return -o1.compareTo(o2); } } private static final long serialVersionUID = 1L; private static NumComparator numComparator = new NumComparator(); /** * map for storing the number of positive and negative labeled examples for * each prediction */ private NavigableMap examples; /** The pos. */ private double pos = 0; /** The neg. */ private double neg = 0; /** * Instantiates a new receiver operating characteristic. */ public ReceiverOperatingCharacteristic() { examples = new TreeMap(numComparator); } /** * Merge two ROCs together. */ public void add(ReceiverOperatingCharacteristic r) { for (Map.Entry entry : r.examples.entrySet()) { increment(entry.getKey(), entry.getValue()); } pos += r.pos; neg += r.neg; } /** * increments count values */ private void increment(Double key, int[] value) { if (!examples.containsKey(key)) { examples.put(key, value); } else { int[] oldVals = examples.get(key); oldVals[0] += value[0]; oldVals[1] += value[1]; examples.put(key, oldVals); } } private void increment(Double prediction, double label) { if (label > 0.5) { pos++; increment(prediction, new int[] { 1, 0 }); } else { neg++; increment(prediction, new int[] { 0, 1 }); } } /** * Adds the. * * @param label * the label * @param prediction * the prediction */ public void add(double label, double prediction) { increment(prediction, label); } /* pair should be in label, prediction order */ public void add(PrimitivePair pair) { add(pair.first, pair.second); } /** * Roc. * * @return the double[][] */ public double[][] ROC() { // checked: examples are in sorted order. 
// pos and neg are correct List curve = new ArrayList(); double tp = 0; double fp = 0; for (int[] counts : examples.values()) { curve.add(new PrimitivePair(fp / neg, tp / pos)); tp += counts[0]; fp += counts[1]; } curve.add(new PrimitivePair(fp / neg, tp / pos)); double[][] out = new double[curve.size()][2]; for (int i = 0; i < curve.size(); i++) { out[i][0] = curve.get(i).second; // tpr out[i][1] = curve.get(i).first; // fpr } return out; } /** * Brier score. * * @return the double */ public double brierScore() { double score = 0; double total = 0; for (Map.Entry entry : examples.entrySet()) { Double pred = entry.getKey(); int[] counts = entry.getValue(); score += counts[0] * (1 - pred) * (1 - pred); score += counts[1] * (0 - pred) * (0 - pred); total += counts[0] + counts[1]; } return score / total; } /** * bins the predictions. looks at the average label compared to the median * prediction for each bin. computes the brier score based on this * * @param bins * the bins * @return the double */ public double averagedBrierScore(int bins) { double score = 0; double predBins = Math.min(bins, pos + neg); double binWidth = 1. / predBins; double bottom = 0.; double top = bottom + binWidth; for (int i = 0; i < predBins; i++) { double num = 0; double avgLabel = 0; NavigableMap subMap = examples.subMap(bottom, true, top, true); for (int[] labels : subMap.values()) { avgLabel += labels[0]; num += labels[0] + labels[1]; } double medianscore = (bottom + top) / 2.; if (num > 0) { avgLabel /= num; score += (medianscore - avgLabel) * (medianscore - avgLabel); } top += binWidth; bottom += binWidth; } return score / predBins; } /** * Binary auc. * * @return the double */ public double binaryAUC() { double[][] ROC = ROC(); double area = 0.0; for (int i = 1; i < ROC.length; i++) { area += trapezoidArea(ROC[i - 1][1], ROC[i][1], ROC[i - 1][0], ROC[i][0]); } area += trapezoidArea(1, ROC[ROC.length - 1][1], 1, ROC[ROC.length - 1][0]); return area; } /** * Trapezoid area. 
* * @param x1 * the x1 * @param x2 * the x2 * @param y1 * the y1 * @param y2 * the y2 * @return the double */ private double trapezoidArea(double x1, double x2, double y1, double y2) { double base = Math.abs(x1 - x2); double avgHeight = (y1 + y2) / 2.0; return base * avgHeight; } /* * (non-Javadoc) * * @see java.lang.Object#toString() */ @Override public String toString() { StringBuffer buff = new StringBuffer(); double[][] out = this.ROC(); for (int i = 0; i < out.length; i++) { buff.append(out[i][0] + "\t" + out[i][1] + "\n"); } return buff.toString(); } /** * Compute auc. * * @param labelsAndPredictions * the labels and predictions * @return the double */ public static double computeAUC( Collection labelsAndPredictions) { ReceiverOperatingCharacteristic roc = new ReceiverOperatingCharacteristic(); for (PrimitivePair p : labelsAndPredictions) roc.add(p.first, p.second); return roc.binaryAUC(); } /** * Compute brier score. * * @param labelsAndPredictions * the labels and predictions * @return the double */ public static double computeBrierScore( Collection labelsAndPredictions) { ReceiverOperatingCharacteristic roc = new ReceiverOperatingCharacteristic(); for (PrimitivePair p : labelsAndPredictions) roc.add(p.first, p.second); return roc.averagedBrierScore(25); } } ================================================ FILE: src/main/java/com/etsy/conjecture/evaluation/RegressionModelEvaluation.java ================================================ package com.etsy.conjecture.evaluation; import java.io.Serializable; import java.util.HashMap; import com.etsy.conjecture.PrimitivePair; import com.etsy.conjecture.data.RealValuedLabel; /** * a basic container for evaluations TODO: add getters for individual metrics */ public class RegressionModelEvaluation implements ModelEvaluation, Serializable { private static final long serialVersionUID = 1L; private double MSE = 0, MAE = 0, examples = 0; public void add(RealValuedLabel real, RealValuedLabel pred) { add(real.getValue(), 
pred.getValue()); } public void merge(ModelEvaluation other) { RegressionModelEvaluation tempOther = (RegressionModelEvaluation) other; MSE += tempOther.MSE; MAE += tempOther.MAE; examples += tempOther.examples; } public void add(double label, double prediction) { double difference = Math.abs(label - prediction); MSE += difference * difference; MAE += difference; examples++; } public void add(PrimitivePair labelPrediction) { add(labelPrediction.getFirst(), labelPrediction.getSecond()); } public double computeMeanSquaredError() { return examples > 0 ? MSE / examples : 0; } public double computeMeanAbsoluteError() { return examples > 0 ? MAE / examples : 0; } public HashMap getStatistics() { HashMap m = new HashMap(); m.put("MSE", computeMeanSquaredError()); m.put("MAE", computeMeanAbsoluteError()); return m; } @Override public String toString() { StringBuilder buff = new StringBuilder(); buff.append("MSE: " + computeMeanSquaredError() + "\n"); buff.append("MAE: " + computeMeanAbsoluteError() + "\n"); return buff.toString(); } public HashMap getObjects() { HashMap m = new HashMap(); return m; } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/AdagradOptimizer.java ================================================ package com.etsy.conjecture.model; import com.etsy.conjecture.*; import com.etsy.conjecture.data.*; import java.util.*; /** * AdaGrad provides adaptive per-feature learning rates at each time step t. 
* Described here: http://www.ark.cs.cmu.edu/cdyer/adagrad.pdf */ public class AdagradOptimizer extends SGDOptimizer { private StringKeyedVector unnormalizedGradients = new StringKeyedVector(); private StringKeyedVector summedGradients = new StringKeyedVector(); @Override public StringKeyedVector getUpdate(LabeledInstance instance) { StringKeyedVector gradients = model.getGradients(instance); StringKeyedVector updateVec = new StringKeyedVector(); Iterator> it = gradients.iterator(); while (it.hasNext()) { Map.Entry pairs = (Map.Entry)it.next(); String feature = pairs.getKey(); double gradient = pairs.getValue(); double featureLearningRate = updateAndGetFeatureLearningRate(feature, gradient); updateVec.setCoordinate(feature, gradient * -featureLearningRate); } return updateVec; } /** * Update adaptive feature specific learning rates */ public double updateAndGetFeatureLearningRate(String feature, double gradient) { double gradUpdate = 0.0; if (summedGradients.containsKey(feature)) { gradUpdate = gradient * gradient; } else { /** * Unmentioned in the literature, but initializing * the squared gradient at 1.0 rather than 0.0 * helps avoid oscillation. */ gradUpdate = 1d+(gradient * gradient); } summedGradients.addToCoordinate(feature, gradUpdate); unnormalizedGradients.addToCoordinate(feature, gradient); return getFeatureLearningRate(feature); } public double getFeatureLearningRate(String feature) { return initialLearningRate/Math.sqrt(summedGradients.getCoordinate(feature)); } /** * Overrides the lazy l1 and l2 regularization in the base class * to do adagrad with l1 regularization. * * Lazily calculates and applies the update that minimizes the l1 * regularized objective. 
See "Adding l1 regularization" in * http://www.ark.cs.cmu.edu/cdyer/adagrad.pdf */ @Override public double lazyUpdate(String feature, double param, long start, long end) { if (Utilities.floatingPointEquals(laplace, 0.0d)) { return param; } for (long iter = start + 1; iter <= end; iter++) { if (Utilities.floatingPointEquals(param, 0.0d)) { return 0.0d; } if (laplace > 0.0) { return adagradL1(feature, param, iter); } } return param; } public double adagradL1(String feature, double param, long iter) { double eta = (initialLearningRate*iter)/Math.sqrt(summedGradients.getCoordinate(feature)); double u = unnormalizedGradients.getCoordinate(feature); double normalizedGradient = u/iter; if (Math.abs(normalizedGradient) <= laplace) { param = 0.0; } else { param = -(Math.signum(u) * eta * (normalizedGradient - laplace)); } return param; } @Override public void teardown() { summedGradients = new StringKeyedVector(); unnormalizedGradients = new StringKeyedVector(); } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/ClusteringModel.java ================================================ package com.etsy.conjecture.model; import com.etsy.conjecture.data.LabeledInstance; import com.etsy.conjecture.data.ClusterLabel; import com.etsy.conjecture.data.MulticlassPrediction; import com.etsy.conjecture.data.StringKeyedVector; import com.etsy.conjecture.data.LazyVector; import com.etsy.conjecture.Utilities; import com.etsy.conjecture.data.Label; import com.etsy.conjecture.data.LabeledInstance; import java.io.Serializable; import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.HashMap; import static com.google.common.base.Preconditions.checkArgument; public abstract class ClusteringModel implements UpdateableModel>, Serializable { static final long serialVersionUID = 666L; protected double projectionErrorTolerance = 0.01; protected double projectionBallRadius = 1.0; protected int 
numClusters = 100; protected Map param = new HashMap(); public void update(LabeledInstance instance) { update(instance.getVector()); } public void update(Collection> instances) { for(LabeledInstance instance : instances) { update(instance.getVector()); } } public abstract void update(StringKeyedVector instance); public abstract ClusterLabel predict(StringKeyedVector instance); protected ClusteringModel() { Map init_param = new HashMap(); for (int i = 0; i < numClusters; i++) { init_param.put(Integer.toString(i), new StringKeyedVector()); } this.param = init_param; } protected ClusteringModel(HashMap param) { Map init_param = new HashMap(); Iterator it = param.entrySet().iterator(); while (it.hasNext()) { Map.Entry pairs = (Map.Entry)it.next(); init_param.put(pairs.getKey(), pairs.getValue()); it.remove(); } this.param = init_param; } public void setFreezeFeatureSet(boolean freeze) { for(Map.Entry e : param.entrySet()) { e.getValue().setFreezeKeySet(freeze); } } public void reScale(double scale) { for(String cat : param.keySet()) { param.get(cat).mul(scale); } } public void merge(ClusteringModel model, double scale) { for(String cat : param.keySet()) { param.get(cat).addScaled(model.param.get(cat), scale); } } public ClusteringModel setNumClusters(int k) { checkArgument(k >= 0, "number of clusters must be non-negative, given: %s", k); this.numClusters = k; return this; } public ClusteringModel setL1ProjectionErrorTolerance(double e) { checkArgument(e >= 0, "error tolerance must be non-negative, given: %s", e); this.projectionErrorTolerance = e; return this; } public ClusteringModel setL1ProjectionBallRadius(double r) { checkArgument(r >= 0, "radius must be non-negative, given: %s", r); this.projectionBallRadius = r; return this; } public Iterator> decompose() { throw new UnsupportedOperationException("not done yet"); } public void setParameter(String name, double value){ throw new UnsupportedOperationException("not done yet"); } public long getEpoch() { return 0; } 
public void setEpoch(long epoch) { // this class doesnt care about epoch. } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/ControlOptimizer.java ================================================ package com.etsy.conjecture.model; import com.etsy.conjecture.data.*; import java.util.*; /** * Current search ads control. Remove after current exp. */ public class ControlOptimizer extends SGDOptimizer { private StringKeyedVector summedGradients = new StringKeyedVector(); @Override public StringKeyedVector getUpdate(LabeledInstance instance) { StringKeyedVector gradients = model.getGradients(instance); StringKeyedVector updateVec = new StringKeyedVector(); Iterator> it = gradients.iterator(); while (it.hasNext()) { Map.Entry pairs = (Map.Entry)it.next(); String feature = pairs.getKey(); double gradient = pairs.getValue(); double featureLearningRate = updateAndGetFeatureLearningRate(feature, gradient); updateVec.setCoordinate(feature, gradient * -featureLearningRate); } return updateVec; } /** * Update adaptive feature specific learning rates */ public double updateAndGetFeatureLearningRate(String feature, double gradient) { double gradUpdate = 0.0; if (summedGradients.containsKey(feature)) { gradUpdate = gradient * gradient; } else { /** * Unmentioned in the literature, but initializing * the squared gradient at 1.0 rather than 0.0 * helps avoid oscillation. 
*/ gradUpdate = 1d+(gradient * gradient); } summedGradients.addToCoordinate(feature, gradUpdate); return getFeatureLearningRate(feature); } public double getFeatureLearningRate(String feature) { return initialLearningRate/Math.sqrt(summedGradients.getCoordinate(feature)); } @Override public void teardown() { summedGradients = new StringKeyedVector(); } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/Decomposable.java ================================================ package com.etsy.conjecture.model; import java.util.Iterator; import java.util.Map; /** * Type of model to be used with the LargeModelTrainer. */ public interface Decomposable { /** * Present the model internals to be summed across submodels. */ public Iterator> decompose(); /** * After rebuilding a blank model, fill in the parameters. */ public void setParameter(String name, double value); } ================================================ FILE: src/main/java/com/etsy/conjecture/model/ElasticNetOptimizer.java ================================================ package com.etsy.conjecture.model; import com.etsy.conjecture.data.*; public class ElasticNetOptimizer extends SGDOptimizer implements LazyVector.UpdateFunction { @Override public StringKeyedVector getUpdate(LabeledInstance instance) { StringKeyedVector gradients = model.getGradients(instance); double learningRate = getDecreasingLearningRate(model.epoch); gradients.mul(-learningRate); return gradients; } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/FTRLOptimizer.java ================================================ package com.etsy.conjecture.model; import com.etsy.conjecture.data.LazyVector; import com.etsy.conjecture.data.StringKeyedVector; import static com.google.common.base.Preconditions.checkArgument; import com.etsy.conjecture.Utilities; import com.etsy.conjecture.data.LabeledInstance; import com.etsy.conjecture.data.Label; import 
java.util.Map; import java.util.Iterator; /** * Implements FTRL-Proximal online learning as described * here: http://static.googleusercontent.com/media/research.google.com/en/us/pubs/archive/41159.pdf */ public class FTRLOptimizer extends SGDOptimizer { private double alpha; private double beta; private StringKeyedVector z = new StringKeyedVector(); private StringKeyedVector n = new StringKeyedVector(); @Override public StringKeyedVector getUpdate(LabeledInstance instance) { FTRLRegularization(instance); StringKeyedVector gradients = model.getGradients(instance); Iterator> it = gradients.iterator(); while (it.hasNext()) { Map.Entry pairs = (Map.Entry)it.next(); String feature = pairs.getKey(); double gradient = pairs.getValue(); double eta = getFeatureLearningRate(feature, gradient); double z_i = 0.0; // if first round, set z_i to 0.0 if (z.containsKey(feature)) { z_i = z.getCoordinate(feature); } double update = (z_i + gradient) - eta * model.param.getCoordinate(feature); z.setCoordinate(feature, update); double n_i = 0.0; // if first round, set n_i to 0.0 if (n.containsKey(feature)) { n_i = n.getCoordinate(feature); } n.setCoordinate(feature, n_i + gradient * gradient); } return new StringKeyedVector(); // Model updates happen in the FTRLRegularization step } public double getFeatureLearningRate(String feature, double gradient) { double n_i = 0.0; if (n.containsKey(feature)) { n_i = n.getCoordinate(feature); } return 1d/alpha * (Math.sqrt(n_i + gradient * gradient) - Math.sqrt(n_i)); } public void FTRLRegularization(LabeledInstance instance) { Iterator> it = instance.getVector().iterator(); while (it.hasNext()) { Map.Entry pairs = (Map.Entry)it.next(); String feature = pairs.getKey(); Double value = pairs.getValue(); double regularizedWeight = getRegularizedWeight(feature); model.param.setCoordinate(feature, regularizedWeight); } } /** * If z doesn't contain the key, it's initialized at 0.0 * and therefore less than laplace which is always >= 0.0 */ public double 
getRegularizedWeight(String feature) { if (z.containsKey(feature)){ double z_i = z.getCoordinate(feature); if (Math.abs(z_i) <= laplace) { return 0.0d; } else { double n_i = n.getCoordinate(feature); double w_i = -1.0/(((beta + Math.sqrt(n_i))/alpha) + gaussian) * (z_i - Math.signum(z_i) * laplace); return w_i; } } else { return 0.0; } } /** * Since we can do sparse regularization updates, lazyUpdate * does nothing and just returns the feature param. */ @Override public double lazyUpdate(String feature, double param, long start, long end) { return param; } public FTRLOptimizer setAlpha(double alpha) { checkArgument(alpha > 0, "alpha must be greater than 0. Given: %s", alpha); this.alpha = alpha; return this; } public FTRLOptimizer setBeta(double beta) { checkArgument(beta > 0, "beta must be greater than 0. Given: %s", beta); this.beta = beta; return this; } @Override public void teardown() { z = new StringKeyedVector(); n = new StringKeyedVector(); } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/Hinge.java ================================================ package com.etsy.conjecture.model; import com.etsy.conjecture.Utilities; import com.etsy.conjecture.data.BinaryLabel; import com.etsy.conjecture.data.LabeledInstance; import com.etsy.conjecture.data.StringKeyedVector; /** * Hinge loss for binary classification tasks with y in {-1,1}. * When threshold=1.0, one gets the loss used by SVM. * When threshold=0.0, one gets the loss used by the Perceptron. 
*/ public class Hinge extends UpdateableLinearModel { private static final long serialVersionUID = 1L; private double threshold = 0.0; public Hinge(SGDOptimizer optimizer) { super(optimizer); } public Hinge(StringKeyedVector param, SGDOptimizer optimizer) { super(param, optimizer); } @Override public BinaryLabel predict(StringKeyedVector instance) { double inner = param.dot(instance); return new BinaryLabel(Utilities.logistic(inner)); } @Override public double loss(LabeledInstance instance) { double inner = param.dot(instance.getVector()); double label = instance.getLabel().getAsPlusMinus(); double z = inner * label; if (z <= this.threshold) { return this.threshold - z; } else { return 0.0; } } @Override public StringKeyedVector getGradients(LabeledInstance instance) { StringKeyedVector gradients = instance.getVector().copy(); double inner = param.dot(instance.getVector()); double label = instance.getLabel().getAsPlusMinus(); double z = inner * label; if (z <= this.threshold) { gradients.mul(-label); return gradients; } else { return new StringKeyedVector(); } } @Override protected String getModelType() { return "hinge"; } public Hinge setThreshold(double threshold) { this.threshold = threshold; return this; } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/KMeans.java ================================================ package com.etsy.conjecture.model; import com.etsy.conjecture.Utilities; import com.etsy.conjecture.data.ClusterLabel; import com.etsy.conjecture.data.StringKeyedVector; import com.etsy.conjecture.data.ClusterPrediction; import com.etsy.conjecture.Utilities; import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.HashMap; import com.google.common.collect.Maps; /** * Implements sparse, streaming kmeans as described here: * http://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf */ public class KMeans extends ClusteringModel { private static 
final long serialVersionUID = 1L; private Map clusterCounts = new HashMap(); public KMeans(String[] categories) { for(String s : categories) { param.put(s, new StringKeyedVector()); clusterCounts.put(s, 0.0); } } private Map predefinedCenters; public KMeans(Map centers) { this.predefinedCenters = Maps.newHashMap(centers); for(String key : predefinedCenters.keySet()) { param.put(key, predefinedCenters.get(key)); clusterCounts.put(key, 0.0); } } public ClusterPrediction predict(StringKeyedVector instance) { Map scores = new HashMap(); for(Map.Entry e : param.entrySet()) { scores.put(e.getKey(), e.getValue().dot(instance)); } return new ClusterPrediction(scores); } public void update(StringKeyedVector instance) { // Get closest center to instance String closest_center = predict(instance).getLabel(); // Update the per center count Double current_count = clusterCounts.get(closest_center); clusterCounts.put(closest_center, current_count+1.0); // Get per center learning rate Double learning_rate = 1.0/clusterCounts.get(closest_center); // take gradient step StringKeyedVector center = param.get(closest_center); center.mul(1-learning_rate); instance.mul(learning_rate); center.add(instance); l1Projection(center); param.put(closest_center, center); } public Double getCurrent(StringKeyedVector center, Double theta) { Double current = 0.0; for (double v : center.values()) { current += Math.max(0, Math.abs(v)-theta); } return current; } /* * Use bisection to find an approximate value of theta */ public Double findTheta(StringKeyedVector center, Double norm) { Double upper = center.max(); Double lower = 0.0; Double current = norm; Double theta = 0.0; while (current > projectionBallRadius * (1 + projectionErrorTolerance)) { theta = (upper + lower)/2.0; current = getCurrent(center, theta); if (current <= projectionBallRadius) { upper = theta; } else { lower = theta; } } return theta; } public void doProjection(StringKeyedVector center, Double theta) { Iterator it = 
center.iterator(); while (it.hasNext()) { Map.Entry pairs = (Map.Entry)it.next(); String key = pairs.getKey(); double value = pairs.getValue(); double projectedValue = Math.signum(value) * Math.max(0.0, Math.abs(value) - theta); center.setCoordinate(key, projectedValue); } } /** * An e-Accurate projection to the L1 ball, described here: * http://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf */ public void l1Projection(StringKeyedVector center) { Double norm = center.LPNorm(1.0); if (norm <= projectionBallRadius + projectionErrorTolerance) { return; } else { Double theta = findTheta(center, norm); doProjection(center, theta); } } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/LeastSquaresRegressionModel.java ================================================ package com.etsy.conjecture.model; import com.etsy.conjecture.data.LabeledInstance; import com.etsy.conjecture.data.RealValuedLabel; import com.etsy.conjecture.data.StringKeyedVector; public class LeastSquaresRegressionModel extends UpdateableLinearModel { private static final long serialVersionUID = 1L; public LeastSquaresRegressionModel(SGDOptimizer optimizer) { super(optimizer); } public LeastSquaresRegressionModel(StringKeyedVector param, SGDOptimizer optimizer) { super(param, optimizer); } @Override public RealValuedLabel predict(StringKeyedVector instance) { return new RealValuedLabel(param.dot(instance)); } @Override public double loss (LabeledInstance instance) { double label = instance.getLabel().getValue(); double hypothesis = param.dot(instance.getVector()); return 0.5 * (hypothesis - label) * (hypothesis - label); } @Override public StringKeyedVector getGradients(LabeledInstance instance) { StringKeyedVector gradients = instance.getVector().copy(); double hypothesis = param.dot(instance.getVector()); double label = instance.getLabel().getValue(); gradients.mul((2 * (hypothesis-label))); return gradients; } @Override protected String getModelType() 
{ return "least_squares_regression"; } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/LogisticRegression.java ================================================ package com.etsy.conjecture.model; import com.etsy.conjecture.Utilities; import com.etsy.conjecture.data.BinaryLabel; import com.etsy.conjecture.data.LabeledInstance; import com.etsy.conjecture.data.StringKeyedVector; /** * Logistic regression loss for binary classification with y in {-1, 1}. */ public class LogisticRegression extends UpdateableLinearModel { private static final long serialVersionUID = 1L; public LogisticRegression(SGDOptimizer optimizer) { super(optimizer); } public LogisticRegression(StringKeyedVector param, SGDOptimizer optimizer) { super(param, optimizer); } @Override public BinaryLabel predict(StringKeyedVector instance) { return new BinaryLabel(Utilities.logistic(instance.dot(param))); } @Override public double loss(LabeledInstance instance) { double inner = instance.getVector().dot(param); double label = instance.getLabel().getAsPlusMinus(); return Math.log(1.0 + Math.exp(-label * inner)); } @Override public StringKeyedVector getGradients(LabeledInstance instance) { StringKeyedVector gradients = instance.getVector().copy(); double label = instance.getLabel().getAsPlusMinus(); double inner = instance.getVector().dot(param); double gradient = -label / (Math.exp(label * inner) + 1.0); gradients.mul(gradient); return gradients; } protected String getModelType() { return "logistic_regression"; } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/MIRA.java ================================================ package com.etsy.conjecture.model; import com.etsy.conjecture.Utilities; import com.etsy.conjecture.data.BinaryLabel; import com.etsy.conjecture.data.LabeledInstance; import com.etsy.conjecture.data.StringKeyedVector; public class MIRA extends UpdateableLinearModel { private static final long 
serialVersionUID = 1L; public MIRA() { super(new MIRAOptimizer()); } public MIRA(StringKeyedVector param, SGDOptimizer optimizer) { super(param, optimizer); } @Override public double loss(LabeledInstance instance) { double label = instance.getLabel().getAsPlusMinus(); double prediction = param.dot(instance.getVector()); double loss = Math.max(0, 1d - label * prediction); return loss; } @Override public BinaryLabel predict(StringKeyedVector instance) { double inner = param.dot(instance); return new BinaryLabel(Utilities.logistic(inner)); } @Override public StringKeyedVector getGradients(LabeledInstance instance) { StringKeyedVector gradients = instance.getVector().copy(); double label = instance.getLabel().getAsPlusMinus(); double prediction = param.dot(instance.getVector()); double loss = Math.max(0, 1d - label * prediction); if (loss > 0) { double norm = instance.getVector().LPNorm(2d); double tau = loss / (norm * norm); gradients.mul(tau * label); return gradients; } else { return new StringKeyedVector(); } } @Override protected String getModelType() { return "MIRA"; } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/MIRAOptimizer.java ================================================ package com.etsy.conjecture.model; import com.etsy.conjecture.data.*; /** * MIRA takes care of the full update. This is basically just a passthrough to * the MIRA getGradients. 
*/ public class MIRAOptimizer extends SGDOptimizer { @Override public StringKeyedVector getUpdate(LabeledInstance instance) { return model.getGradients(instance); } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/Model.java ================================================ package com.etsy.conjecture.model; import java.io.Serializable; import com.etsy.conjecture.data.Label; import com.etsy.conjecture.data.StringKeyedVector; public interface Model extends Serializable { public L predict(StringKeyedVector instance); } ================================================ FILE: src/main/java/com/etsy/conjecture/model/PassiveAggressiveOptimizer.java ================================================ package com.etsy.conjecture.model; import com.etsy.conjecture.data.LazyVector; import com.etsy.conjecture.data.StringKeyedVector; import static com.google.common.base.Preconditions.checkArgument; import com.etsy.conjecture.Utilities; import com.etsy.conjecture.data.Label; import com.etsy.conjecture.data.LabeledInstance; import com.etsy.conjecture.data.RealValuedLabel; /** * See http://eprints.pascal-network.org/archive/00002147/01/CrammerDeKeShSi06.pdf * for a discussion of PA Regression. */ public class PassiveAggressiveOptimizer extends SGDOptimizer { private double C; private boolean isHinge; @Override public StringKeyedVector getUpdate(LabeledInstance instance) { double norm = instance.getVector().LPNorm(2d); double update = model.loss(instance) / (norm * norm + 0.5 / C); if(isHinge) { /** * Classification. Scale update by label in {-1, 1}. 
*/ update = update * (2.0 * (instance.getLabel().getValue() - 0.5)); } else if (instance.getLabel().getValue() - ((RealValuedLabel)model.predict(instance.getVector())).getValue() < 0.0) { /** Regression **/ update = update * -1; } StringKeyedVector updateVec = instance.getVector().copy(); updateVec.mul(update); return updateVec; } public PassiveAggressiveOptimizer setC(double C) { checkArgument(C > 0, "C must be greater than 0. Given: %s", C); this.C = C; return this; } public PassiveAggressiveOptimizer isHinge(boolean isHinge) { this.isHinge = isHinge; return this; } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/SGDOptimizer.java ================================================ package com.etsy.conjecture.model; import com.etsy.conjecture.data.LazyVector; import com.etsy.conjecture.Utilities; import static com.google.common.base.Preconditions.checkArgument; import com.etsy.conjecture.data.Label; import com.etsy.conjecture.data.LabeledInstance; import com.etsy.conjecture.data.StringKeyedVector; import java.util.Collection; /** * Builds the weight updates as a function * of learning rate and regularization schedule for SGD learning. 
* * Default learning rate and regularization are: * LR: Exponentially decreasing * REG: Lazily applied L1 and L2 regularization * Subclasses overwrite LR and REG functions as necessary */ public abstract class SGDOptimizer implements LazyVector.UpdateFunction { private static final long serialVersionUID = 9153480933266800474L; double laplace = 0.0; double gaussian = 0.0; double initialLearningRate = 0.01; transient UpdateableLinearModel model; double examplesPerEpoch = 10000; boolean useExponentialLearningRate = false; double exponentialLearningRateBase = 0.99; public SGDOptimizer() {} public SGDOptimizer(double g, double l) { gaussian = g; laplace = l; } /** * Do minibatch gradient descent */ public StringKeyedVector getUpdates(Collection> minibatch) { StringKeyedVector updateVec = new StringKeyedVector(); for (LabeledInstance instance : minibatch) { updateVec.add(getUpdate(instance)); // accumulate gradient model.truncate(instance); model.epoch++; } updateVec.mul(1.0/minibatch.size()); // do a single update, scaling weights by the // average gradient over the minibatch return updateVec; } /** * Get the update to the param vector using a chosen * learning rate / regularization schedule. * Returns a StringKeyedVector of updates for each * parameter. */ public abstract StringKeyedVector getUpdate(LabeledInstance instance); public void teardown() { } /** * Implements lazy updating of regularization when the regularization * updates aren't sparse (e.g. elastic net l1 and l2, adagrad l1). * * When regularization can be done on just the non-zero elements of * the sample instance (e.g. FTRL proximal, HandsFree), the lazyUpdate * function does nothing (i.e. just returns the unscaled param). 
*/ public double lazyUpdate(String feature, double param, long start, long end) { if (Utilities.floatingPointEquals(laplace, 0.0d) && Utilities.floatingPointEquals(gaussian, 0.0d)) { return param; } for (long iter = start + 1; iter <= end; iter++) { if (Utilities.floatingPointEquals(param, 0.0d)) { return 0.0d; } double eta = getDecreasingLearningRate(iter); /** * TODO: patch so that param cannot cross 0.0 during gaussian update */ param -= eta * gaussian * param; if (param > 0.0) { param = Math.max(0.0, param - eta * laplace); } else { param = Math.min(0.0, param + eta * laplace); } } return param; } /** * Computes a linearly or exponentially decreasing * learning rate as a function of the current epoch. * Even when we have per feature learning rates, it's * necessary to keep track of a decreasing learning rate * for things like truncation. */ public double getDecreasingLearningRate(long t){ double epoch_fudged = Math.max(1.0, (t + 1) / examplesPerEpoch); if (useExponentialLearningRate) { return Math.max( 0d, this.initialLearningRate * Math.pow(this.exponentialLearningRateBase, epoch_fudged)); } else { return Math.max(0d, this.initialLearningRate / epoch_fudged); } } public SGDOptimizer setInitialLearningRate(double rate) { checkArgument(rate > 0, "Initial learning rate must be greater than 0. 
Given: %s", rate); this.initialLearningRate = rate; return this; } public SGDOptimizer setExamplesPerEpoch(double examples) { checkArgument(examples > 0, "examples per epoch must be positive, given %f", examples); this.examplesPerEpoch = examples; return this; } public SGDOptimizer setUseExponentialLearningRate(boolean useExponentialLearningRate) { this.useExponentialLearningRate = useExponentialLearningRate; return this; } public SGDOptimizer setExponentialLearningRateBase(double base) { checkArgument(base > 0, "exponential learning rate base must be positive, given: %f", base); checkArgument( base <= 1.0, "exponential learning rate base must be at most 1.0, given: %f", base); this.exponentialLearningRateBase = base; return this; } public SGDOptimizer setGaussianRegularizationWeight(double gaussian) { checkArgument(gaussian >= 0.0, "gaussian regularization weight must be non-negative, given: %f", gaussian); this.gaussian = gaussian; return this; } public SGDOptimizer setLaplaceRegularizationWeight(double laplace) { checkArgument(laplace >= 0.0, "laplace regularization weight must be non-negative, given: %f", laplace); this.laplace = laplace; return this; } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/UpdateableLinearModel.java ================================================ package com.etsy.conjecture.model; import static com.google.common.base.Preconditions.checkArgument; import gnu.trove.function.TDoubleFunction; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import com.etsy.conjecture.Utilities; import com.etsy.conjecture.data.Label; import com.etsy.conjecture.data.LabeledInstance; import com.etsy.conjecture.data.LazyVector; import com.etsy.conjecture.data.StringKeyedVector; public abstract class UpdateableLinearModel implements UpdateableModel>, Comparable>, Serializable { private static final 
long serialVersionUID = 8549108867384062857L; protected LazyVector param; protected final String modelType; protected long epoch; protected SGDOptimizer optimizer; // parameters for gradient truncation // for more info, see: // http://jmlr.csail.mit.edu/papers/volume10/langford09a/langford09a.pdf protected int period = 0; protected double truncationUpdate = 0.1; protected double truncationThreshold = 0.0; private String argString = "NOT SET"; public void setArgString(String s) { argString = s; } public String getArgString() { return argString; } public double dotWithParam(StringKeyedVector x) { return param.dot(x); } protected UpdateableLinearModel(SGDOptimizer optimizer) { this.optimizer = optimizer; this.param = new LazyVector(100, optimizer); epoch = 0; modelType = getModelType(); } protected UpdateableLinearModel(StringKeyedVector param, SGDOptimizer optimizer) { this.optimizer = optimizer; optimizer.model = this; this.param = new LazyVector(param, optimizer); epoch = 0; modelType = getModelType(); } /** * Get a StringKeyedVector holding the gradient of the loss w.r.t. every model parameter. 
*/ public abstract StringKeyedVector getGradients(LabeledInstance instance); /** * Minibatch gradient update */ public void update(Collection> instances) { optimizer.model = this; // avoid serialization stackoverflow if (epoch > 0) { param.incrementIteration(); } StringKeyedVector updates = optimizer.getUpdates(instances); param.add(updates); } /** * Single gradient update */ public void update(LabeledInstance instance) { optimizer.model = this; // avoid serialization stackoverflow if (epoch > 0) { param.incrementIteration(); } StringKeyedVector update = optimizer.getUpdate(instance); param.add(update); truncate(instance); epoch++; } public abstract L predict(StringKeyedVector instance); public abstract double loss(LabeledInstance instance); protected abstract String getModelType(); public Iterator> decompose() { return param.iterator(); } public void setParameter(String name, double value) { param.setCoordinate(name, value); } public StringKeyedVector getParam() { return param; } public void reScale(double scale) { param.mul(scale); } public void setFreezeFeatureSet(boolean freeze) { param.setFreezeKeySet(freeze); } public void merge(UpdateableLinearModel model, double scaling) { param.addScaled(model.param, scaling); epoch += model.epoch; } public void teardown() { optimizer.teardown(); } /** * Decide based on period and epoch whether to truncate */ public void truncate(LabeledInstance instance) { if (period > 0 && epoch > 0 && epoch % period == 0) { applyTruncation(instance.getVector()); } } public void applyTruncation(StringKeyedVector instance) { final double update = this.optimizer.getDecreasingLearningRate(epoch) * truncationUpdate; final double threshold = truncationThreshold; TDoubleFunction truncFn = new TDoubleFunction() { public double execute(double parameter) { if (parameter > 0 && parameter < threshold) { return Math.max(0, parameter - update); } else if (parameter < 0 && parameter > -threshold) { return Math.min(0, parameter + update); } else { 
return parameter; } } }; param.transformValues(truncFn); param.removeZeroCoordinates(); } public long getEpoch() { return epoch; } public void setEpoch(long e) { epoch = e; } public UpdateableLinearModel setTruncationPeriod(int period) { checkArgument(period >= 0, "period must be non-negative, given: %s", period); this.period = period; return this; } public UpdateableLinearModel setTruncationThreshold(double threshold) { checkArgument(threshold >= 0, "update must be non-negative, given: %s", threshold); this.truncationThreshold = threshold; return this; } public UpdateableLinearModel setTruncationUpdate(double update) { checkArgument(update >= 0, "update must be non-negative, given: %s", update); this.truncationUpdate = update; return this; } @Override public int compareTo(UpdateableLinearModel inputModel) { return (int)Math.signum(inputModel.param.LPNorm(2d) - param.LPNorm(2d)); } public void thresholdParameters(double t) { for (Iterator> it = param.iterator(); it .hasNext();) { if (Math.abs(it.next().getValue()) < t) { it.remove(); } } } public String explainPrediction(StringKeyedVector x) { return explainPrediction(x, -1); } public String explainPrediction(StringKeyedVector x, int n) { StringBuilder out = new StringBuilder(); Map weights = new HashMap(); for (String dim : x.keySet()) { if (param.getCoordinate(dim) != 0.0) { weights.put( dim, Math.abs(x.getCoordinate(dim) * param.getCoordinate(dim))); } } ArrayList keys = com.etsy.conjecture.Utilities .orderKeysByValue(weights, true); for (int i = 0; (n == -1 || i < n) && i < keys.size(); i++) { String k = keys.get(i); out.append(k + ":" + String.format("%.2f", x.getCoordinate(k)) + "->" + String.format("%.2f", param.getCoordinate(k)) + " "); } return out.toString(); } } ================================================ FILE: src/main/java/com/etsy/conjecture/model/UpdateableModel.java ================================================ package com.etsy.conjecture.model; import java.util.Collection; import 
com.etsy.conjecture.data.Label; import com.etsy.conjecture.data.LabeledInstance; public interface UpdateableModel> extends Model, Decomposable { // - update the model with a single labeled instance. public void update(LabeledInstance instance); // - update the model with many labeled instances. public void update(Collection> instances); // - merge two models together. public void merge(M model, double weight); // - multiply the parameter vector by a constant. public void reScale(double scale); // - set whether to add unseen-features when updating. public void setFreezeFeatureSet(boolean freeze); // - reset the epoch number after model merging. public void setEpoch(long epoch); public long getEpoch(); } ================================================ FILE: src/main/java/com/etsy/conjecture/model/UpdateableMulticlassLinearModel.java ================================================ package com.etsy.conjecture.model; import static com.google.common.base.Preconditions.checkArgument; import gnu.trove.function.TDoubleFunction; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import com.etsy.conjecture.Utilities; import com.etsy.conjecture.data.MulticlassLabel; import com.etsy.conjecture.data.LabeledInstance; import com.etsy.conjecture.data.BinaryLabeledInstance; import com.etsy.conjecture.data.MulticlassLabeledInstance; import com.etsy.conjecture.data.MulticlassPrediction; import com.etsy.conjecture.data.LazyVector; import com.etsy.conjecture.data.StringKeyedVector; import com.etsy.conjecture.data.RealValuedLabel; import com.etsy.conjecture.data.BinaryLabel; public class UpdateableMulticlassLinearModel implements UpdateableModel, Comparable, Serializable { private static final long serialVersionUID = 8549108867384062857L; protected String modelType; private String argString = "NOT SET"; protected long epoch; protected Map> param = new HashMap>(); public 
UpdateableMulticlassLinearModel(Map> param) { this.param = param; this.epoch = 0; this.modelType = this.getModelType(); } public void setArgString(String s) { argString = s; } public String getArgString() { return argString; } public void setModelType(String modelType) { this.modelType = modelType; } public String getModelType() { return modelType; } public Iterator> decompose() { throw new UnsupportedOperationException("not done yet"); } public void setParameter(String name, double value) { throw new UnsupportedOperationException("not done yet"); } public void reScale(double scale) { for (String cat : param.keySet()) { param.get(cat).param.mul(scale); } } public void setFreezeFeatureSet(boolean freeze) { for (Map.Entry> e : param.entrySet()) { e.getValue().param.setFreezeKeySet(freeze); } } /** * Minibatch gradient update */ public void update(Collection> instances) { for (LabeledInstance instance : instances) { update(instance); } } /** * Single gradient update. */ public void update(LabeledInstance instance) { for (Map.Entry> e : param.entrySet()) { String category = e.getKey(); UpdateableLinearModel model = e.getValue(); double label = e.getKey().equals(instance.getLabel().getLabel()) ? 
1.0 : 0.0; BinaryLabeledInstance blInstance = new BinaryLabeledInstance(label, instance.getVector()); model.update(blInstance); } epoch++; } @Override public MulticlassPrediction predict(StringKeyedVector instance) { Map scores = new HashMap(); double normalization = 0; for (Map.Entry> e : param.entrySet()) { double prediction = ((RealValuedLabel)e.getValue().predict(instance)).getValue(); scores.put(e.getKey(), prediction); normalization += prediction; } for (Map.Entry e : scores.entrySet()) { scores.put(e.getKey(), e.getValue() / normalization); } return new MulticlassPrediction(scores); } public void merge(UpdateableMulticlassLinearModel model, double scale) { for (String cat : param.keySet()) { param.get(cat).param.addScaled(model.param.get(cat).param, scale); } epoch += model.epoch; } public void teardown() { for (Map.Entry> e : param.entrySet()) { e.getValue().teardown(); } } public long getEpoch() { return epoch; } public void setEpoch(long e) { epoch = e; } // what to do here? @Override public int compareTo(UpdateableMulticlassLinearModel inputModel) { return (int)Math.signum(inputModel.getEpoch() - getEpoch()); } public void thresholdParameters(double t) { for (UpdateableLinearModel m : param.values()) { for (Iterator> it = m.param.iterator(); it .hasNext();) { if (Math.abs(it.next().getValue()) < t) { it.remove(); } } } } public String explainPrediction(StringKeyedVector x) { return explainPrediction(x, -1); } public String explainPrediction(StringKeyedVector x, int n) { throw new UnsupportedOperationException("not done yet"); } } ================================================ FILE: src/main/java/com/etsy/conjecture/topics/lda/LDADenseTopics.java ================================================ package com.etsy.conjecture.topics.lda; import java.io.Serializable; import java.util.Random; public class LDADenseTopics implements LDATopics, Serializable { private static final long serialVersionUID = 8704084406257021101L; int num_topics; int dict_size; 
double[][] topic_prob; LDADict dict; Random rnd = new Random(); public LDADenseTopics(double[][] topic_prob) { this.num_topics = topic_prob.length; this.dict_size = topic_prob[0].length; this.topic_prob = topic_prob; } public void setTopicProb(int topic, double[] prob) { topic_prob[topic] = prob; } public void setDict(LDADict dict_) throws Exception { if (dict_.size() < dict_size) throw new Exception("trying to set the dict with size " + dict_.size() + " on a topic model with dict size " + dict_size); dict = dict_; dict_size = dict.size(); } public LDADict getDict() { return dict; } public double wordProb(int word, int topic) { return topic_prob[topic][word]; } public int numTopics() { return num_topics; } public int dictSize() { return dict_size; } public String toString() { StringBuilder b = new StringBuilder(); for (int k = 0; k < num_topics; k++) { b.append(k + ": "); for (int w = 0; w < dict_size; w++) { if (dict == null) b.append(w + ":" + String.format("%.3f, ", topic_prob[k][w])); else b.append(dict.word(w) + ":" + String.format("%.3f, ", topic_prob[k][w])); } b.append("\n"); } return b.toString(); } } ================================================ FILE: src/main/java/com/etsy/conjecture/topics/lda/LDADict.java ================================================ package com.etsy.conjecture.topics.lda; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; import java.util.Set; public class LDADict implements Serializable { private static final long serialVersionUID = 2363682000942209420L; private ArrayList words; private HashMap dict; public LDADict(Set unique_words) { words = new ArrayList(unique_words.size()); dict = new HashMap(); for (String s : unique_words) { words.add(s); dict.put(s, dict.size()); } } public String word(int index) { return words.get(index); } public int index(String word) { return dict.get(word); } public int size() { return words.size(); } public boolean contains(String word) { return 
dict.containsKey(word); } public String toString() { return "LDADict(size: " + size() + ")"; } } ================================================ FILE: src/main/java/com/etsy/conjecture/topics/lda/LDADoc.java ================================================ package com.etsy.conjecture.topics.lda; import com.etsy.conjecture.Utilities; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.Map; public class LDADoc implements Serializable { private static final long serialVersionUID = 1536967875771864807L; double[] topic_proportions; double total_words; int[] word_idx; double[] word_count; double[][] phi; boolean phi_dirty; public LDADoc(Map word_counts, LDADict dict) { total_words = 0.0; word_idx = new int[word_counts.size()]; word_count = new double[word_counts.size()]; phi_dirty = true; int i = 0; for (Map.Entry e : word_counts.entrySet()) { word_idx[i++] = dict.index(e.getKey()); total_words += e.getValue(); } // Keep parallel arrays in sorted order of word index, for easier // aggregation // of partial topic models. Arrays.sort(word_idx); for (int w = 0; w < word_idx.length; w++) { word_count[w] = word_counts.get(dict.word(word_idx[w])); } } public double[] topicProportions() { return topic_proportions; } public double wordCount() { return total_words; } public void updateTopicProportions(LDATopics topics, double alpha) { int K = topics.numTopics(); // reuse old topic proportions unless the topic model has changed. if (topic_proportions == null || topic_proportions.length != topics.numTopics()) { topic_proportions = new double[K]; for (int k = 0; k < K; k++) { topic_proportions[k] = total_words / (double)K; } } if (phi == null || phi[0].length != topics.numTopics()) { phi = new double[word_idx.length][K]; } // iterate the update procedure. double[] topic_proportions_new = new double[K]; double[] phi_z = new double[word_idx.length]; while (true) { // Compute phi. 
for (int k = 0; k < K; k++) { double digamma_k = LDAUtils.digamma(topic_proportions[k]); for (int w = 0; w < word_idx.length; w++) { double wp = Math.log(topics.wordProb(word_idx[w], k)); phi[w][k] = digamma_k + wp; if (k == 0) { phi_z[w] = phi[w][k]; } else { phi_z[w] = LDAUtils.logSumExp(phi_z[w], phi[w][k]); } } } // Compute updated gamma. double conv = 0.0; for (int k = 0; k < K; k++) { topic_proportions_new[k] = alpha; for (int w = 0; w < word_idx.length; w++) { phi[w][k] = Math.exp(phi[w][k] - phi_z[w]) * word_count[w]; topic_proportions_new[k] += phi[w][k]; } double diff = topic_proportions[k] - topic_proportions_new[k]; topic_proportions[k] = topic_proportions_new[k]; conv += diff * diff; } // Check convergence. if (conv < 1000.0) { break; } } phi_dirty = false; } // You can only call this after calling updateTopicProportions.. public LDAPartialTopics toPartialTopics() throws Exception { if (phi_dirty) { throw new Exception( "Called toPartialTopics() on a doc that hasnt been updated"); } return new LDAPartialTopics(word_idx, phi); } public LDAPartialTopics toPartialTopic(int topic) throws Exception { if (phi_dirty) { throw new Exception( "Called toPartialTopics() on a doc that hasnt been updated"); } double[][] phi_k = new double[word_idx.length][1]; // duh for (int i = 0; i < word_idx.length; i++) { phi_k[i][0] = phi[i][topic]; } return new LDAPartialTopics(word_idx, phi_k); } public LDAPartialSparseTopics toPartialSparseTopics(int n) throws Exception { if (phi_dirty) { throw new Exception( "Called toPartialTopics() on a doc that hasnt been updated"); } int K = topic_proportions.length; Map partial_phi = new HashMap(); Map word_topic_prob = new HashMap(); for (int w = 0; w < word_idx.length; w++) { word_topic_prob.clear(); for (int k = 0; k < K; k++) { word_topic_prob.put(k, phi[w][k]); } ArrayList sorted_topics = Utilities.orderKeysByValue( word_topic_prob, true); double z = 0.0; for (int i = 0; i < n; i++) { z += phi[w][sorted_topics.get(i)]; } 
word_topic_prob.clear(); for (int i = 0; i < n; i++) { int k = sorted_topics.get(i); int v = word_idx[w]; partial_phi.put(v * K + k, (phi[w][k] / z) * word_count[w]); } } return new LDAPartialSparseTopics(K, partial_phi); } } ================================================ FILE: src/main/java/com/etsy/conjecture/topics/lda/LDAPartialSparseTopics.java ================================================ package com.etsy.conjecture.topics.lda; import java.io.Serializable; import java.util.Map; import java.util.Set; public class LDAPartialSparseTopics implements Serializable { private static final long serialVersionUID = -5073459183590344302L; private int K; private Map phi; public LDAPartialSparseTopics(int K, Map phi) { this.K = K; this.phi = phi; } public LDAPartialSparseTopics merge(LDAPartialSparseTopics rhs) throws Exception { if (K != rhs.K) { throw new Exception( "Try to merge partials with different nubmer of topics: " + K + " and " + rhs.K); } Map a = phi.size() < rhs.phi.size() ? phi : rhs.phi; Map b = phi.size() < rhs.phi.size() ? rhs.phi : phi; for (Map.Entry e : a.entrySet()) { if (b.containsKey(e.getKey())) { b.put(e.getKey(), e.getValue() + b.get(e.getKey())); } else { b.put(e.getKey(), e.getValue()); } } return new LDAPartialSparseTopics(K, b); } public LDASparseTopics toTopics() { // renormalize. 
double[] z = new double[K]; for (Map.Entry e : phi.entrySet()) { z[e.getKey() % K] += e.getValue(); } Set keys = phi.keySet(); for (int i : keys) { phi.put(i, phi.get(i) / z[i % K]); } return new LDASparseTopics(K, phi); } } ================================================ FILE: src/main/java/com/etsy/conjecture/topics/lda/LDAPartialTopics.java ================================================ package com.etsy.conjecture.topics.lda; import java.io.Serializable; public class LDAPartialTopics implements Serializable { private static final long serialVersionUID = 3590284302630767864L; private int[] word_index; private double[][] phi; public LDAPartialTopics(int[] word_index, double[][] phi) { this.word_index = word_index; this.phi = phi; } private int countUniqueWords(LDAPartialTopics rhs) { // First determine the number of unique words in both sides. int num_words = 0; int lhs_idx = 0; int rhs_idx = 0; while (lhs_idx < word_index.length && rhs_idx < rhs.word_index.length) { int lhs_word = word_index[lhs_idx]; int rhs_word = rhs.word_index[rhs_idx]; if (lhs_word <= rhs_word) { lhs_idx++; } if (rhs_word <= lhs_word) { rhs_idx++; } num_words++; } // add word for whatever pointers not reached the end if (lhs_idx != word_index.length) { num_words += word_index.length - lhs_idx; } if (rhs_idx != rhs.word_index.length) { num_words += rhs.word_index.length - rhs_idx; } return num_words; } public LDAPartialTopics merge(LDAPartialTopics rhs) throws Exception { if (phi[0].length != rhs.phi[0].length) { throw new Exception( "Try to merge partials with different nubmer of topics: " + phi.length + " and " + rhs.phi.length); } int K = phi[0].length; int num_words = countUniqueWords(rhs); int[] word_idx_new = new int[num_words]; double[][] phi_new = new double[num_words][K]; int new_idx = 0; int lhs_idx = 0; int rhs_idx = 0; while (lhs_idx < word_index.length && rhs_idx < rhs.word_index.length) { int lhs_word = word_index[lhs_idx]; int rhs_word = rhs.word_index[rhs_idx]; if (lhs_word 
< rhs_word) { word_idx_new[new_idx] = lhs_word; for (int k = 0; k < K; k++) { phi_new[new_idx][k] = phi[lhs_idx][k]; } lhs_idx++; } else if (rhs_word < lhs_word) { word_idx_new[new_idx] = rhs_word; for (int k = 0; k < K; k++) { phi_new[new_idx][k] = rhs.phi[rhs_idx][k]; } rhs_idx++; } else { word_idx_new[new_idx] = rhs_word; for (int k = 0; k < K; k++) { phi_new[new_idx][k] = rhs.phi[rhs_idx][k] + phi[lhs_idx][k]; } rhs_idx++; lhs_idx++; } new_idx++; } // add word for whatever pointers not reached the end for (; lhs_idx < word_index.length; lhs_idx++) { int lhs_word = word_index[lhs_idx]; word_idx_new[new_idx] = lhs_word; for (int k = 0; k < K; k++) { phi_new[new_idx][k] = phi[lhs_idx][k]; } new_idx++; } for (; rhs_idx != rhs.word_index.length; rhs_idx++) { int rhs_word = rhs.word_index[rhs_idx]; word_idx_new[new_idx] = rhs_word; for (int k = 0; k < K; k++) { phi_new[new_idx][k] = rhs.phi[rhs_idx][k]; } new_idx++; } return new LDAPartialTopics(word_idx_new, phi_new); } public String toString() { StringBuilder b = new StringBuilder(); for (int i = 0; i < word_index.length; i++) { b.append(i + " - " + word_index[i] + ": "); for (int k = 0; k < phi[i].length; k++) { b.append(phi[i][k] + ", "); } b.append("\n"); } return b.toString(); } public double[][] toTopicVectors() { // Ensure that words_index has no gaps. 
int word_max = word_index[word_index.length - 1]; int K = phi[0].length; double[][] phi_new = new double[K][word_max + 1]; for (int k = 0; k < K; k++) { double z = 0.0; for (int i = 0; i < word_index.length; i++) { int w = word_index[i]; phi_new[k][w] = phi[i][k]; z += phi[i][k]; } for (int i = 0; i < phi_new[k].length; i++) { phi_new[k][i] /= z; } } return phi_new; } public double[] toTopicVector() throws Exception { double[][] phi_new = toTopicVectors(); if (phi_new.length > 1) { throw new Exception( "called toTopicVector() on a thing with multiple vectors"); } return phi_new[0]; } public LDADenseTopics toTopics() { return new LDADenseTopics(toTopicVectors()); } public static void main(String[] argv) throws Exception { int[] words_lhs = new int[] { 1, 2, 4, 7, 10 }; double[][] phi_lhs = new double[][] { { 0.3, 0.7 }, { 0.2, 0.8 }, { 0.1, 0.9 }, { 0.5, 0.5 }, { 0.9, 0.1 } }; int[] words_rhs = new int[] { 4, 10, 11, 12, 15 }; double[][] phi_rhs = new double[][] { { 0.4, 0.6 }, { 0.3, 0.7 }, { 0.3, 0.7 }, { 0.1, 0.9 }, { 0.4, 0.6 }, { 0.5, 0.5 } }; LDAPartialTopics lhs = new LDAPartialTopics(words_lhs, phi_lhs); LDAPartialTopics rhs = new LDAPartialTopics(words_rhs, phi_rhs); System.out.println(lhs); System.out.println(rhs); System.out.println(rhs.merge(lhs)); System.out.println(lhs.merge(rhs)); System.out.println(lhs.merge(rhs).toTopics()); } } ================================================ FILE: src/main/java/com/etsy/conjecture/topics/lda/LDARandomTopics.java ================================================ package com.etsy.conjecture.topics.lda; import java.io.Serializable; import java.util.Random; public class LDARandomTopics implements LDATopics, Serializable { private static final long serialVersionUID = -3258304331549481829L; int num_topics; int dict_size; LDADict dict; Random rnd = new Random(); public LDARandomTopics(LDADict dict, int num_topics) { this.num_topics = num_topics; this.dict_size = dict.size(); this.dict = dict; } public double wordProb(int 
word, int topic) { // not gonna normalize or anything, central limit theorem bro. rnd.setSeed(topic * dict_size + word); double mean = 1.0 / dict_size; // So if theres 100 words, return something between 0.005 and 0.015 double rand = Math.max(0.0, mean + (rnd.nextBoolean() ? 1 : -1) * rnd.nextDouble() * (mean / 2)); return rand; } public int numTopics() { return num_topics; } public int dictSize() { return dict_size; } public LDADict getDict() { return dict; } public void setDict(LDADict d) { dict = d; dict_size = d.size(); } } ================================================ FILE: src/main/java/com/etsy/conjecture/topics/lda/LDASparseTopics.java ================================================ package com.etsy.conjecture.topics.lda; import java.io.Serializable; import java.util.Map; public class LDASparseTopics implements LDATopics, Serializable { private static final long serialVersionUID = 4878060449289865652L; int K; Map prob; LDADict dict; public LDASparseTopics(int K, Map prob) { this.prob = prob; this.K = K; } public void setDict(LDADict dict_) { dict = dict_; } public LDADict getDict() { return dict; } public double wordProb(int word, int topic) { int key = word * K + topic; if (prob.containsKey(key)) { return prob.get(key); } else { return 0.00000001; } } public int numTopics() { return K; } public int dictSize() { return dict.size(); } } ================================================ FILE: src/main/java/com/etsy/conjecture/topics/lda/LDATopics.java ================================================ package com.etsy.conjecture.topics.lda; import java.io.Serializable; public interface LDATopics extends Serializable { public void setDict(LDADict dict) throws Exception; public LDADict getDict(); public double wordProb(int word, int topic); public int numTopics(); public int dictSize(); } ================================================ FILE: src/main/java/com/etsy/conjecture/topics/lda/LDAUtils.java ================================================ package 
com.etsy.conjecture.topics.lda; import java.io.Serializable; public class LDAUtils implements Serializable { private static final long serialVersionUID = -1142647262716539345L; public static double digamma(double x) { if (x > 6.0) { double x2 = x * x; double x4 = x2 * x2; double x6 = x2 * x4; double x8 = x4 * x4; double x10 = x6 * x4; double x12 = x6 * x6; double x14 = x10 * x4; return Math.log(x) - 1.0 / (2 * x) - 1.0 / (12 * x2) - 1.0 / (120 * x4) - 1.0 / (252 * x6) + 1.0 / (240 * x8) - 5.0 / (660 * x10) + 691.0 / (32760 * x12) - 1.0 / (12 * x14); } else { return digamma(x + 1.0) - (1.0 / x); } } public static double logSumExp(double a, double b) { double x = (a < b) ? a : b; double y = (a < b) ? b : a; if (y - x > 50) { return y; } else { return x + Math.log(1.0 + Math.exp(y - x)); } } } ================================================ FILE: src/main/scala/com/etsy/conjecture/VWReader.scala ================================================ package com.etsy.conjecture import cascading.pipe.Pipe import cascading.flow.FlowDef import com.twitter.scalding._ import com.etsy.conjecture.data._ import scala.collection.generic import scala.util.matching.Regex // Input: line file in VW format // Writes MulticlassLabeledInstances in JSON trait VWReader { import Dsl._ def parse(input: String): MulticlassLabeledInstance = { // parse header val a = input.split("""\s*\|""").toList var b = a(0).split("""\s+""").toList var label = b(0) var importance = 1.0 var tag = "" try { if (b.length > 1) importance = b(1).toDouble if (b.length > 2) tag = b(2) } catch { case e: Exception => println("Ignoring header") } // create inst with header info val instObj = new MulticlassLabeledInstance(label) instObj.setId(tag) // parse remainder val remainder = input.split("\\s+").toList val pipePattern = """(.*)\|(.*)""" val pipeReg = (pipePattern).r var pastHeader = false var namespace = "" remainder.map { token: String => if (pipeReg.pattern.matcher(token).matches) { val pipeReg(before, after) = 
token if (pastHeader) addFeature(instObj, before, namespace) namespace = extractNamespace(after) // will be "" if no namespace pastHeader = true } else { if (pastHeader) addFeature(instObj, token, namespace) } } instObj } def extractNamespace(token: String): String = { val pairReg = ("""(.+)\:(.+)""").r var namespace = token if (pairReg.pattern.matcher(token).matches) { val pairReg(term, value) = token namespace = term // TODO: return weight for namespace when models can handle that } namespace } def setId(instObj: MulticlassLabeledInstance, id: String): Boolean = { instObj.setId(id) true } def setImportance(instObj: MulticlassLabeledInstance, token: String): Boolean = { // TODO: set importance weighting here once we support it true } def addFeature(instObj: MulticlassLabeledInstance, token: String, namespace: String) { val pairPattern = """(.+)\:(.+)""" val pairReg = (pairPattern).r if (token == "") return try { if (pairReg.pattern.matcher(token).matches) { val pairReg(term, value) = token if (namespace == "") instObj.addTerm(term, value.toDouble) // catch numberFormatException else instObj.addTermWithNamespace(term, namespace, value.toDouble) } else { if (namespace == "") instObj.addTerm(token) else instObj.addTermWithNamespace(token, namespace) } } catch { case e: Exception => println("Ignore line: " + token) } } } ================================================ FILE: src/main/scala/com/etsy/conjecture/demo/DemoLinearHyperparameterSearch.scala ================================================ package com.etsy.scalding.jobs.conjecture import scala.util.Random import com.etsy.conjecture.scalding.util._ import com.etsy.conjecture.data._ import com.etsy.conjecture.model._ import com.twitter.scalding._ import cascading.tuple.Fields /** * An example of a custom Job that would run the hyperparameter searcher for a Binary model * Takes command line arguments: * input: Path to training/testing data * out_dir: Path to output directory * num_trials: The number of models to 
train over random settings
 * model: Type of linear model to use for training
 */
class DemoLinearHyperparameterSearch(args : Args) extends BaseGridSearcher(args) {
  /*
   * Define the settings to be optimized as below.
   * DynamicOptions are generic type containers for Args that can perform various metric calculations.
   * All command line parameters given to this job are automatically added.
   * DynamicOption takes output name of param and the default value of the param.
   */
  class DefaultClassifierOptions extends DynamicOptions(args) {
    val laplace = new DynamicOption("laplace", 0.0)
    val gauss = new DynamicOption("gauss", 0.0)
    val rate = new DynamicOption("rate", 0.1)
    val numIters = new DynamicOption("iter", 5)
  }

  //Define your parameters to optimize
  val opts = new DefaultClassifierOptions

  //For each parameter you wish to optimize and defined above, create a hyperparameter instance with the dynamic container and the sampler type
  val parameters: Seq[HyperParameter[_]] = {
    val laplace = new HyperParameter(opts.laplace, new LogUniformDoubleSampler(1e-8, 1e-1))
    val gauss = new HyperParameter(opts.gauss, new LogUniformDoubleSampler(1e-8, 1e-1))
    val rate = new HyperParameter(opts.rate, new SampleFromSeq(List(.01, .001, .0001, .00001, .000001)))
    val iters = new HyperParameter(opts.numIters, new SampleFromSeq(List(3, 5)))
    Seq(laplace, gauss, rate, iters)
  }

  //Define model type, Binary, Multiclass, or Regression
  // NOTE(review): numTrials, instances, instance_field and out_dir are not defined in this
  // class; presumably they are inherited from BaseGridSearcher.
  val searcher = new BinaryHyperparameterSearcher(opts, parameters, numTrials)

  // Call hyperparameter search to run and write to given file location
  val (results, report) = searcher.search(instances, instance_field)

  //Write to the file of your choice.
  results.write(SequenceFile(out_dir + "/trialSummary"))
  report.write(SequenceFile(out_dir + "/parameterReport"))
}
================================================
FILE: src/main/scala/com/etsy/conjecture/demo/IrisDataToMulticlassLabeledInstances.scala
================================================
package com.etsy.conjecture.demo

import com.twitter.scalding._
import com.etsy.conjecture.data._

class IrisDataToMulticlassLabeledInstances(args: Args) extends Job(args) {
  // This class just converts the tsv of iris data to a sequence file of multiclass labeled instances
  // which the AdHocClassifier can then use to train.
  // Note that for a dataset of this size, the use of a hadoop job is overkill, this is for demonstration
  // purposes.
  // Each input line is tab-separated: four numeric measurements, then the class name in the fifth column.
  // NOTE(review): a line with fewer than five fields (e.g. a trailing blank line) makes parts(4) throw.
  TextLine(args.getOrElse("input_file", "iris.tsv"))
    .mapTo('instance) { l: String =>
      val names = Array("sepal_length", "sepal_width", "petal_length", "petal_width")
      val parts = l.split("\t")
      val instance = new MulticlassLabeledInstance(parts(4))
      (0 until 4).foreach { i => instance.setCoordinate(names(i), parts(i).toDouble) }
      instance
    }
    .write(SequenceFile(args.getOrElse("output_file", "iris_instances")))
}
================================================
FILE: src/main/scala/com/etsy/conjecture/demo/LearnMulticlassClassifier.scala
================================================
package com.etsy.conjecture.demo

import com.twitter.scalding._
import com.etsy.conjecture.scalding.evaluate.{ MulticlassCrossValidator, MulticlassEvaluator }
import com.etsy.conjecture.scalding.train.MulticlassModelTrainer
import com.etsy.conjecture.data.{ MulticlassLabel, MulticlassLabeledInstance }
import com.etsy.conjecture.model.UpdateableMulticlassLinearModel
import com.google.gson.Gson
import cascading.tuple.Fields

// Trains a multiclass model over the instances in --input.
// Writes the model to <output>/model (SequenceFile) and <output>/model_json (Gson JSON).
// When --folds > 0, also writes cross-validation results to <output>/xval and predictions to <output>/pred.
class LearnMulticlassClassifier(args: Args) extends Job(args) {
  val input = args("input")
  val out_dir = args.getOrElse("output", "multiclass_classifier")
  val class_names = args("class_names").split(",")
  val folds = args.getOrElse("folds", "0").toInt

  // Let the user configure the field names on the command line.
  val data_field_names = args.getOrElse("data_fields", "instance").split(",")
  val data_fields = data_field_names.tail.foldLeft(new Fields(data_field_names.head)) { (x, y) => x.append(new Fields(y)) }
  val instance_field = Symbol(args.getOrElse("instance_field", "instance"))
  val instances = SequenceFile(input, data_fields).project(instance_field)

  val model_pipe = new MulticlassModelTrainer(args, class_names)
    .train(instances, instance_field, 'model)

  model_pipe
    .write(SequenceFile(out_dir + "/model"))
    .mapTo('model -> 'json) { x: UpdateableMulticlassLinearModel => new Gson().toJson(x) }
    .write(Tsv(out_dir + "/model_json"))

  if (folds > 0) {
    val eval_pred = new MulticlassCrossValidator(args, folds, class_names)
      .crossValidateWithPredictions(instances, instance_field, 'pred)
    eval_pred._1
      .write(Tsv(out_dir + "/xval"))
    eval_pred._2
      .write(Tsv(out_dir + "/pred"))
  }
}
================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/ALSJob.scala
================================================
package com.etsy.conjecture.scalding

import cascading.pipe.Pipe
import cascading.pipe.joiner.InnerJoin
import com.twitter.scalding.{Args, Job, Mode, SequenceFile}
import org.apache.commons.math3.linear._

/**
 * An abstract job class to implement alternating least squares for matrix factorization.
 * Since the method is iterative, this job overrides job.next rather than trying to
 * build a single massive cascading flow. This means that the job is more robust to failure, and
 * also doesn't crash the cascading planner with a giant graph.
 *
 * The concrete job class which extends this just has to override the function s() which returns a pipe
 * having fields ('row, 'col, 'value) representing the matrix to factorize. This is only computed on the
 * first iteration (or when update_matrix is set), and then written to disk.
Therefore the function should be self contained, so that the
 * job doesn't try to do pointless work on every iteration.
 *
 * There are some other fields which the child class can override in order to get specific behavior from
 * the method:
 *
 * - zero_weight: the weight of zeros in the matrix, where nonzeros are given weight 1.
 *   (As implemented, zero_weight * U'U already covers every entry and the observed outer products are
 *   added on top, so observed entries effectively get weight 1 + zero_weight; their targets are scaled
 *   by 1 + zero_weight to match.)
 * - norm_constraint: whether to force the norms of rows of the factors to 1 (useful for doing LSH for max-product search).
 * - lambda_row, lambda_col: L2 regularization parameters on the two factors.
 *
 * Command-line arguments read below: dim, iter, max_iter, parallelism, base_dir, incremental,
 * update_matrix, input_u_path, output_u_path, output_v_path.
 */
abstract class ALSJob[R, C](args : Args) extends Job(args) {

  // Raise the Hadoop child JVM heap to 3GB; the reducers below materialize per-key lists and n x n matrices.
  override def config: Map[AnyRef, AnyRef] = super.config + ("mapred.child.java.opts" -> "-Xmx3G")

  // Dimension of latent factors.
  val n = args.getOrElse("dim", "200").toInt
  // Index of the current iteration; next() relaunches the job with iter + 1.
  val iter = args.getOrElse("iter", "0").toInt
  val max_iter = args.getOrElse("max_iter", "15").toInt
  // Reducer count for the joins and groupBys below.
  val parallelism = args.getOrElse("parallelism", "500").toInt
  val base_dir = args.getOrElse("base_dir", "als")

  // The weight of zero terms in the matrices.
  val zero_weight = 0.001

  // data for s matrix, must have fields ('row, 'col, 'value)
  def s() : Pipe

  def norm_constraint : Boolean = false
  def lambda_row : Double = 0.0f
  def lambda_col : Double = 0.0f

  // When set, iteration 0 reads U from input_u_path instead of initializing it,
  // and next() never schedules another iteration.
  val incremental = args.boolean("incremental")

  // allow overriding input and output paths.
  val input_u_path = args.getOrElse("input_u_path", base_dir+"/U/"+(iter-1))
  val output_u_path = args.getOrElse("output_u_path", base_dir+"/U/"+iter)
  val output_v_path = args.getOrElse("output_v_path", base_dir+"/V/"+iter)

  // technique to initialize the vector:
  // a unit-norm Gaussian random vector seeded by the row's hashCode, so it is reproducible per row.
  def initial_vector(row : R) : RealVector = {
    val rand = new scala.util.Random(row.hashCode)
    val vec = MatrixUtils.createRealVector((0 until n).map{i => rand.nextGaussian}.toArray)
    vec.mapDivide(vec.getNorm)
  }

  // Materialize the input matrix to base_dir/S on iteration 0 (or when update_matrix is set); re-read it otherwise.
  val S = if(iter == 0 || args.boolean("update_matrix")) {
    s().project('row, 'col, 'value).write(SequenceFile(base_dir + "/S"))
  } else {
    SequenceFile(base_dir+"/S", ('row, 'col, 'value)).read
  }

  if(iter == 0 && !incremental) {
    // Initial item factors.
    S
      .groupBy('row){_.size('count)}
      .map('row -> 'u_vec)(initial_vector)
      .project('row, 'u_vec)
      .write(SequenceFile(base_dir+"/U/0"))
  } else {
    // Perform iteration of dual alternating least squares.
    val U = SequenceFile(input_u_path, ('row, 'u_vec)).read

    // -- Update V first.
    // Compute U'U
    val UU = U.mapTo('u_vec -> 'UU){u : RealVector => u.outerProduct(u)}
      .groupAll{_.reduce[RealMatrix]('UU){(a, b) => a.add(b)}}

    // For each column c, with i ranging over the rows that have an entry in c, solve
    //   (zero_weight * U'U + sum_i u_i u_i' + lambda I) v_c = (1 + zero_weight) * sum_i value_i u_i
    // NOTE(review): getInverse.operate forms an explicit inverse; getSolver.solve(Xy) would solve directly.
    val V = S.joinWithSmaller('row -> 'row, U, new InnerJoin(), parallelism)
      .groupBy('col){_.toList[(RealVector, Double)](('u_vec, 'value) -> 'u_list).reducers(parallelism).forceToReducers}
      .crossWithTiny(UU)
      .mapTo(('col, 'u_list, 'UU) -> ('col, 'v_vec)){ x : (C, List[(RealVector, Double)], RealMatrix) =>
        val col_id = x._1
        var XX = x._3.scalarMultiply(zero_weight)
        var Xy = x._2.view.map{t => t._1.mapMultiply(t._2)}.reduce{(a,b) => a.add(b)}.mapMultiply(1.0 + zero_weight)
        x._2.foreach{t => XX = XX.add(t._1.outerProduct(t._1))}
        val lambda = if(norm_constraint) compute_lambda(XX, Xy) else lambda_col
        val res = new LUDecomposition(XX.add(MatrixUtils.createRealIdentityMatrix(XX.getRowDimension).scalarMultiply(lambda))).getSolver.getInverse.operate(Xy)
        (col_id, res)
      }
      .write(SequenceFile(output_v_path))

    // -- Finally update U.
    // Same system as above with the roles of rows and columns swapped, using the new V.
    val VV = V.mapTo('v_vec -> 'VV){u : RealVector => u.outerProduct(u)}
      .groupAll{_.reduce[RealMatrix]('VV){(a, b) => a.add(b)}}

    S
      .joinWithSmaller('col -> 'col, V, new InnerJoin(), parallelism)
      .groupBy('row){_.toList[(RealVector, Double)](('v_vec, 'value) -> 'v_list).reducers(parallelism).forceToReducers}
      .crossWithTiny(VV)
      .mapTo(('row, 'v_list, 'VV) -> ('row, 'u_vec)){ x : (R, List[(RealVector, Double)], RealMatrix) =>
        val row_id = x._1
        var XX = x._3.scalarMultiply(zero_weight)
        var Xy = x._2.view.map{t => t._1.mapMultiply(t._2)}.reduce{(a,b) => a.add(b)}.mapMultiply(1.0 + zero_weight)
        x._2.foreach{t => XX = XX.add(t._1.outerProduct(t._1))}
        val lambda = if(norm_constraint) compute_lambda(XX, Xy) else lambda_row
        val res = new LUDecomposition(XX.add(MatrixUtils.createRealIdentityMatrix(XX.getRowDimension).scalarMultiply(lambda))).getSolver.getInverse.operate(Xy)
        (row_id, res)
      }
      .write(SequenceFile(output_u_path))
  }

  // for the norm constrained version, compute the lambda necessary so that the output vector has unit norm.
  // With XX = V diag(e) V' and u = V'Xy, the solution norm is
  //   ||(XX + lambda I)^-1 Xy||^2 = sum_i u_i^2 / (e_i + lambda)^2,
  // which decreases in lambda for lambda > -min(e).
  // The bracket starts at lambda_min = -min(e) + 1e-9 and lambda_max = ||u||; assuming XX is positive
  // semi-definite (it is a weighted sum of outer products), the squared norm at lambda_max is at most 1.
  // Bisection stops once the squared norms at the two ends differ by at most 1e-4.
  def compute_lambda(XX : RealMatrix, Xy : RealVector) : Double = {
    val eigen = new EigenDecomposition(XX)
    val u = eigen.getVT.operate(Xy)
    // approximate the lagrange multiplier
    var lambda_max = math.sqrt(u.dotProduct(u))
    var lambda_min = -eigen.getRealEigenvalues.min+0.000000001
    var norm_max = (0 until u.getDimension).map{i => val ui = u.getEntry(i); val ei = eigen.getRealEigenvalue(i); ui*ui / ((ei+lambda_max)*(ei+lambda_max))}.sum
    var norm_min = (0 until u.getDimension).map{i => val ui = u.getEntry(i); val ei = eigen.getRealEigenvalue(i); ui*ui / ((ei+lambda_min)*(ei+lambda_min))}.sum
    while(math.abs(norm_max - norm_min) > 0.0001) {
      val lambda_mid = (lambda_max + lambda_min) / 2.0
      val norm_mid = (0 until u.getDimension).map{i => val ui = u.getEntry(i); val ei = eigen.getRealEigenvalue(i); ui*ui / ((ei+lambda_mid)*(ei+lambda_mid))}.sum
      if(norm_mid < 1) {
        lambda_max = lambda_mid
        norm_max = norm_mid
      } else {
        lambda_min = lambda_mid
        norm_min = norm_mid
      }
    }
    val lambda = (lambda_max + lambda_min) / 2
    lambda
  }

  // Chain iterations: clone this job with iter + 1 while iter < max_iter.
  // Incremental runs do a single pass.
  override def next : Option[Job] = {
    val new_args = args + ("iter", Some((iter+1).toString))
    if(iter < max_iter && !incremental) {
      Some(clone(new_args))
    } else {
      None
    }
  }
}
================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/FastKNN.scala
================================================
package com.etsy.conjecture.scalding

import collection.mutable.PriorityQueue
import com.twitter.scalding._
import cascading.pipe.Pipe
import cascading.pipe.joiner.InnerJoin
import org.apache.commons.math3.linear.{MatrixUtils, RealVector}

object FastKNN extends Serializable {
  import com.twitter.scalding.Dsl._

  // The basic idea is that we do KNN on arbitrary types.
  // These can be objects containing e.g., identifiers (user_id etc) and also correspond to some point in a metric space.
  // Examples are objects like (user_id, vector from matrix factorization model).
// Typically when we do the KNN procedure, we don't actually care about returning the entire object for all the neighbors,
// but only the list of the ids, and their distances.
// E.g., we would return the list of (user_id, distance) rather than (user_id, vector, distance).
// This allows having larger lists of stuff in ram since the vector etc may be large.

// The main entry point for knn in a single pipe.
// Every element of p is both a target and a candidate; an element is never returned as its own neighbor.
// X: Type of the element on which the distance is defined (the thing in the vec_field).
// Y: Type of id for the element (thing in the id_field).
// p: Pipe of stuff to knn
// id_field: Field name for id
// vec_field: field name for vec.
// neighb_field: field name for the result: a List[(Y, Double)] of at most k (id, distance) pairs, closest first.
// k: the k from knn.
// dist: the distance function for X. If the thing you give isn't a real distance function then probably this method will give you garbage results.
// init_num_centers: Number of blocks to partition the data into, should be probably around sqrt(n).
// bins_per_point: how many blocks to put each point into (increasing quality of the approximation).
// max_bin_size: blocks holding more points than this get extra centers during initialization and are skipped in the exact search.
def knn[X, Y](p : Pipe, id_field : Symbol, vec_field : Symbol, neighb_field : Symbol, k : Int, dist : (X, X) => Double, init_num_centers : Int = 10000, bins_per_point : Int = 5, max_bin_size : Int = 20000) : Pipe = {
  val centers = initialize_bins[X](p, id_field, vec_field, dist, init_num_centers, bins_per_point, max_bin_size)
  // Do knn in each cluster, and aggregate.
  construct_bins[X, Y](p, id_field, vec_field, 'list, centers, bins_per_point, max_bin_size, dist)
    .filter('count){c : Int => c <= max_bin_size}
    .flatMapTo('list -> (id_field, neighb_field)){l : List[(Y, X)] =>
      // Ask for k+1 so the element itself can be dropped, then cap at k: the element may not be among
      // the k+1 (e.g. ties at distance zero), and an id with a single row is emitted without going
      // through the reduce function below.
      l.view.map{t => (t._1, knn_id[X, Y](t._2, l, k+1, dist).filter{_._1 != t._1}.take(k))}
    }
    .groupBy(id_field){_.reduce[List[(Y, Double)]](neighb_field){(a, b) => (a++b).groupBy{_._1}.toList.map{t => (t._1, t._2.map{_._2}.min)}.sortBy{_._2}.take(k)}.reducers(1000).forceToReducers}
    .project(id_field, neighb_field)
}

// The entry point for the 2 pipe version of knn.
// For every target, finds its k nearest candidates (neighb_field: a List[(Y, Double)] of (candidate id, distance), closest first).
// X: type of the elements on which dist is defined (both target_vec_field and candidate_vec_field).
// Y is the type for the id field of the candidates.
// Z is the type for the id field of the targets.
def knn2[X, Y, Z](targets : Pipe, target_id_field : Symbol, target_vec_field : Symbol, candidates : Pipe, candidate_id_field : Symbol, candidate_vec_field : Symbol, neighb_field : Symbol, k : Int, dist : (X, X) => Double, init_num_centers : Int, bins_per_point : Int, max_bin_size : Int) : Pipe = {
  // Tesselate the candidates.
  val candidate_centers = initialize_bins[X](candidates, candidate_id_field, candidate_vec_field, dist, init_num_centers, 1, max_bin_size)
  val candidate_assignments = construct_bins[X, Y](candidates, candidate_id_field, candidate_vec_field, 'candidate_list, candidate_centers, 1, max_bin_size, dist)
  // Assign targets to same bins as candidates.
  val target_assignments = assign_bins[X](targets, target_id_field, target_vec_field, candidate_centers, bins_per_point, dist)
  // Replicate the candidates, and fragment the targets:
  // a bin with c targets is split into 1 + c / max_bin_size target fragments, and every fragment
  // is joined with its own copy of that bin's candidate list.
  val bin_replicates = target_assignments
    .groupBy('bin){_.size('count)}
    .map('count -> 'num_fragments){c : Int => 1 + (c / max_bin_size)}
    .groupAll{_.toList[(Int, Int)](('bin, 'num_fragments) -> 'bin_replicates)}
    .mapTo('bin_replicates -> 'bin_replicates){l : List[(Int, Int)] => l.toMap}
  val targets_fragmented = target_assignments
    .crossWithTiny(bin_replicates)
    .map((target_id_field, 'bin, 'bin_replicates) -> ('rep_bin, 'rep)){x : (Z, Int, Map[Int, Int]) =>
      // Take the remainder before abs: math.abs(Int.MinValue) is still negative, which would give a
      // fragment index with no candidate replica and silently drop the target in the join below.
      (x._2, math.abs(x._1.hashCode % x._3.getOrElse(x._2, 1)))
    }
    .groupBy('rep_bin, 'rep){_.toList[(Z, X)]((target_id_field, target_vec_field) -> 'target_list).reducers(1000)}
  val candidates_replicated = candidate_assignments
    .crossWithTiny(bin_replicates)
    .flatMap(('bin, 'bin_replicates) -> ('rep_bin, 'rep)){x : (Int, Map[Int, Int]) => (0 until x._2.getOrElse(x._1, 1)).map{i => (x._1, i)}}
    .project('rep_bin, 'rep, 'candidate_list)
  // Do knn in each cluster, and aggregate.
  candidates_replicated
    .joinWithSmaller(('rep_bin, 'rep) -> ('rep_bin, 'rep), targets_fragmented, new InnerJoin(), 1000)
    .flatMapTo(('target_list, 'candidate_list) -> (target_id_field, neighb_field)){x : (List[(Z, X)], List[(Y, X)]) =>
      x._1.view.map{t => (t._1, knn_id[X, Y](t._2, x._2, k, dist))}
    }
    .groupBy(target_id_field){_.reduce[List[(Y, Double)]](neighb_field){(a, b) => (a++b).groupBy{_._1}.toList.map{t => (t._1, t._2.map{_._2}.min)}.sortBy{_._2}.take(k)}.reducers(1000)}
    .project(target_id_field, neighb_field)
}

// Return the ids of the closest elements to the target.
// X is the type of the element on which the distance is defined.
// Y is the type of the identifier for each element.
// Keeps at most K entries, sorted by increasing distance to target. For K <= 250 the candidates are streamed
// through a bounded max-heap keyed on distance (its head is the worst of the current best K); above that
// every candidate is scored and sorted.
def knn_id[X, Y](target : X, candidates : List[(Y, X)], K : Int, dist : (X, X) => Double) : List[(Y, Double)] = {
  if(K > 250) {
    candidates.map{s => (s._1, dist(target, s._2))}.sortBy{_._2}.take(K)
  } else {
    val q = new PriorityQueue[(Y, Double)]()(Ordering.by[(Y, Double), Double](_._2))
    var worst = 0.0
    var size = 0
    candidates.foreach{s =>
      val ds = dist(target, s._2)
      if(size < K || ds < worst) {
        size += 1
        q.enqueue((s._1, ds))
        if(size > K) {
          q.dequeue
          size -= 1
        }
        worst = q.head._2
      }
    }
    q.toList.sortBy{_._2}
  }
}

// Return the indices (positions in l) of the K elements closest to vec, closest first,
// using the same bounded max-heap as knn_id.
def knn_idx[X](vec : X, l : List[X], K : Int, dist : (X, X) => Double) : List[Int] = {
  val q = new PriorityQueue[(Int, Double)]()(Ordering.by[(Int, Double), Double](_._2))
  var worst = 0.0
  var size = 0
  var idx = 0
  l.foreach{r : X =>
    val di = dist(vec, r)
    if(size < K || di < worst) {
      size += 1
      q.enqueue((idx, di))
      if(size > K) {
        q.dequeue
        size -= 1
      }
      worst = q.head._2
    }
    idx += 1
  }
  q.toList.sortBy{_._2}.map{_._1}
}

// Chooses the bin centers: a deterministic pseudo-random sample of about init_num_centers elements, plus
// extra centers picked inside every bin that would receive more than max_bin_size elements.
// Returns a single-row pipe whose 'centers field holds a List[X]; bin ids are indices into that list.
def initialize_bins[X](p : Pipe, id_field : Symbol, vec_field : Symbol, dist : (X, X) => Double, init_num_centers : Int, bins_per_point : Int, max_bin_size : Int) : Pipe = {
  // Choose init_num_centers points at random.
  val centers = p
    .map(vec_field -> 'rand){r : X => new scala.util.Random(r.toString.hashCode).nextDouble}
    .groupRandomly(math.min(1000, init_num_centers)){_.sortWithTake[(X, Double)]((vec_field, 'rand) -> 'centers, 1 + (init_num_centers / 1000)){(a, b) => a._2 > b._2}}
    .groupAll{_.reduce[List[(X, Double)]]('centers){(a, b) => a++b}}
    .mapTo('centers -> 'centers){l : List[(X, Double)] => l.sortBy{-_._2}.take(init_num_centers).map{_._1}}
  // Split oversized bins: pick count / max_bin_size extra centers inside each bin holding more than max_bin_size points.
  val centers_new = p.crossWithTiny(centers)
    .flatMap(('centers, vec_field) -> 'bin){x : (List[X], X) => knn_idx[X](x._2, x._1, bins_per_point, dist)}
    .project('bin, vec_field)
    .map(vec_field -> 'rand){r : X => new scala.util.Random((r.toString+"foo").hashCode).nextDouble}
    .groupBy('bin){_.size('count).sortWithTake[(X, Double)]((vec_field, 'rand) -> 'centers, 1000){(a, b) => a._2 > b._2}}
    .filter('count){c : Int => c > max_bin_size}
    .mapTo(('centers, 'count) -> 'centers_new){l : (List[(X, Double)], Int) =>
      val md = l._1.maxBy{_._2}._2
      l._1.sortBy{t => val d = t._2 / md; -d * (1-d)}.take(l._2 / max_bin_size).map{_._1}
    }
    // The mapTo above emits List[X] (the random keys are dropped), so reduce it as such.
    .groupAll{_.reduce[List[X]]('centers_new){(a, b) => a++b}}
  // centers_new has no rows at all when no bin is oversized, and crossWithTiny against an empty pipe
  // yields no rows, which would leave every element without a bin. Merge the two lists instead, tagging
  // each part so the sampled centers keep their positions ahead of the extra ones.
  val base_part = centers.insert('part, 0)
  val extra_part = centers_new.rename('centers_new -> 'centers).insert('part, 1)
  (base_part ++ extra_part)
    .groupAll{_.toList[(Int, List[X])](('part, 'centers) -> 'parts)}
    .mapTo('parts -> 'centers){parts : List[(Int, List[X])] => parts.sortBy{_._1}.flatMap{_._2}}
}

// Emits one row per (element, bin) pair, 'bin being the index of each of the element's
// bins_per_point nearest centers.
def assign_bins[X](p : Pipe, id_field : Symbol, vec_field : Symbol, centers : Pipe, bins_per_point : Int, dist : (X, X) => Double) : Pipe = {
  // Make assignments to clusters.
  p.crossWithTiny(centers)
    .flatMap((vec_field, 'centers) -> 'bin){x : (X, List[X]) => knn_idx[X](x._1, x._2, bins_per_point, dist)}
    .discard('centers)
}

// Groups the elements by bin: list_field holds at most max_bin_size (id, element) pairs and 'count the
// true number of elements assigned to the bin.
def construct_bins[X, Y](p : Pipe, id_field : Symbol, vec_field : Symbol, list_field : Symbol, centers : Pipe, bins_per_point : Int, max_bin_size : Int, dist : (X, X) => Double) : Pipe = {
  assign_bins[X](p, id_field, vec_field, centers, bins_per_point, dist)
    .groupBy('bin){
      // The comparator never orders two entries, so which max_bin_size entries are kept is unspecified.
      _.sortWithTake[(Y, X)]((id_field, vec_field) -> list_field, max_bin_size){(a, b) => false}
        .size('count)
        .reducers(1000)
    }
    .project('bin, 'count, list_field)
}
}
================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/LSH.scala
================================================
package com.etsy.conjecture.scalding

import collection.mutable.PriorityQueue
import cascading.pipe.Pipe
import cascading.pipe.joiner.InnerJoin
import org.apache.commons.math3.linear.RealVector

/**
 * Class provides functions for doing approximate K-nearest neighbors.
 * hashes : The number of times to hash.
 * planes : The number of dividing planes (also bits in the hash).
 * max_bin_size : The max size for a hash bin to be considered (we do exact knn in each bin, so large ones will increase computation time).
 * parallelism : How many reducers to use for critical sections.
 * defaults are sane for most problems.
 * more hashes = more chance for true knn to be in the same hash bin as the target, but also means more computation.
 * more planes = fewer items in each hash bucket, which improves computation but also could degrade approximation quality.
 */
class LSH(val hashes : Int = 50, val planes : Int = 12, val max_bin_size : Int = 10000, val parallelism : Int = 500) extends Serializable {
  // import needed to write scalding-like code.
  import com.twitter.scalding.Dsl._

  /**
   * Just a class to hold an id and a vector together.
*/
class Point[T](val id : T, val vector : RealVector) extends Serializable {}

/**
 * Brute force knn for inside each hash bin.
 * Works faster than just using obvious scala ways (map/sortBy etc).
 * Returns up to K (point, distance) pairs, closest first, using RealVector.getDistance.
 */
def findKnn[T](vec : RealVector, points : Iterable[Point[T]], K : Int) : List[(Point[T], Double)] = {
  val q = new PriorityQueue[(Point[T], Double)]()(Ordering.by[(Point[T], Double), Double](_._2))
  var worst = 0.0
  var size = 0
  points.foreach{p : Point[T] =>
    val dist = p.vector.getDistance(vec)
    if(size < K || dist < worst) {
      size += 1
      q.enqueue((p, dist))
      if(size > K) {
        q.dequeue
        size -= 1
      }
      worst = q.head._2
    }
  }
  q.toList.sortBy{_._2}
}

/**
 * Hash repeatedly by dividing the space along origin-containing planes.
 * v : The vector to hash.
 * output is the list of hashes, each having its index as part of the value:
 * bit i is set when v lies on the positive side of plane i, and the hash index h occupies the bits above `planes`.
 */
def hash(v : RealVector) : IndexedSeq[Long] = {
  // Copy the coordinates once, rather than once for every (hash, plane) pair.
  val coords = v.toArray
  (0 until hashes).map{h =>
    (0 until planes).map{i =>
      // Seeding by (hash, plane) makes every vector use the same random plane normal.
      // NOTE(review): the original note warns scala.util.Random behaves poorly with small seeds; these seeds
      // also collide once planes > 1000.
      val r = new scala.util.Random(i+1000*h)
      var d = 0.0
      var j = 0
      while(j < coords.length) {
        d += coords(j) * r.nextGaussian
        j += 1
      }
      if(d > 0.0) 1L << i else 0L
    }.sum + (h.toLong << planes)
  }
}

/**
 * Forms hash bins from a single pipe of vectors and ids.
 * Emits (hash_field, bin_field) where bin_field is a List[Point[I]]; buckets holding more than max_bin_size
 * points are dropped.
 */
def form_bins[I](p : Pipe, id_field : Symbol, vec_field : Symbol, bin_field : Symbol, hash_field : Symbol) : Pipe = {
  p
    .map((id_field, vec_field) -> 'point){x : (I, RealVector) => new Point[I](x._1, x._2)}
    .flatMap(vec_field -> hash_field){v : RealVector => hash(v)}
    .project('point, hash_field)
    .groupBy(hash_field){
      _.size('count)
        .sortWithTake[Point[I]]('point -> bin_field, max_bin_size){(a,b) => false}
        .reducers(parallelism)
        .forceToReducers
    }
    .filter('count){c : Int => c <= max_bin_size}
    .project(hash_field, bin_field)
}

/**
 * Single pipe version of knn.
 * Finds knn of each element in the pipe (i.e., every element is both a target and a candidate neighbor)
 * A thing is not its own nearest neighbor.
 * I is the type of id used.
*/
def knn[I](p : Pipe, id_field : Symbol, vec_field : Symbol, neighbors_field : Symbol, K : Int) : Pipe = {
  form_bins[I](p, id_field, vec_field, 'bin, 'hash)
    .flatMapTo('bin -> (id_field, neighbors_field)){ bin : List[Point[I]] =>
      bin.view.map{point =>
        // Ask for K+1 so the point itself can be dropped, then cap at K: the point may not be among the
        // K+1 (e.g. ties at distance zero), and an id found in a single bin skips the reduce function below.
        val neighbors = findKnn[I](point.vector, bin, K+1).filter{_._1.id != point.id}.take(K)
        (point.id, neighbors.map{t => (t._1.id, t._2)}) // (id, distance)
      }
    }
    // - aggregate knn across hash bins.
    .groupBy(id_field) {
      _.reduce[List[(I, Double)]](neighbors_field){(a, b) => (a ++ b).groupBy{_._1}.mapValues{_.head._2}.toList.sortBy{_._2}.take(K)}
        .forceToReducers
        .reducers(parallelism)
    }
    .project(id_field, neighbors_field)
}

/**
 * Two pipe version of knn.
 * First pipe is targets (things we find the knn for) ids are of type I
 * Second pipe is candidates (things that can be the knn) ids are of type J
 * A thing can be its own neighbor if it's in both pipes.
 * neighbors_field holds a List[(J, Double)] of at most K (candidate id, distance) pairs, closest first.
 */
def knn[I,J](targets : Pipe, target_id_field : Symbol, target_vec_field : Symbol, candidates : Pipe, candidate_id_field : Symbol, candidate_vec_field : Symbol, neighbors_field : Symbol, K : Int) : Pipe = {
  form_bins[I](targets, target_id_field, target_vec_field, 'target_bin, 'hash)
    .joinWithSmaller('hash -> 'hash, form_bins[J](candidates, candidate_id_field, candidate_vec_field, 'candidate_bin, 'hash), new InnerJoin(), parallelism)
    .flatMapTo(('target_bin, 'candidate_bin) -> (target_id_field, neighbors_field)){ x : (List[Point[I]], List[Point[J]]) =>
      x._1.view.map{p =>
        (p.id, findKnn[J](p.vector, x._2, K).map{t => (t._1.id, t._2)}) // (id, distance)
      }
    }
    // - aggregate knn across hash bins.
    .groupBy(target_id_field) {
      _.reduce[List[(J, Double)]](neighbors_field){(a, b) => (a ++ b).groupBy{_._1}.mapValues{_.head._2}.toList.sortBy{_._2}.take(K)}
        .forceToReducers
        .reducers(parallelism)
    }
    .project(target_id_field, neighbors_field)
}
}
================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/NNMF.scala
================================================
package com.etsy.conjecture.scalding

import org.apache.commons.math3.linear._
import com.etsy.scalding._
import com.twitter.algebird.Operators._
import com.twitter.scalding._
import cascading.flow.FlowDef
import cascading.pipe.Pipe
import cascading.pipe.joiner.InnerJoin
import cascading.tuple.Fields

object NNMF extends Serializable {
  import com.twitter.scalding.Dsl._

  // based on http://research.microsoft.com/pubs/119077/dnmf.pdf

  /**
   * input:
   * A: a sparse matrix in the form ('row, 'col, 'val), with tuples of type (R, C, Double).
   * k: the dimension of the factorization.
   * output:
   * (H0, W0): one ('row, 'vec, 'bias) tuple per distinct 'row of A and one ('col, 'vec, 'bias) tuple per
   * distinct 'col, with entries and biases drawn from math.random (uniform on [0, 1), despite the name).
   */
  def initGaussian(A : Pipe, k : Int, reducers : Int = 500) : (Pipe, Pipe) = {
    val H0 = A.groupBy('row){_.size('count).reducers(reducers)}
      .map(() -> 'vec){_ : Unit => MatrixUtils.createRealVector((0 until k).map{i => math.random}.toArray)}
      .map(() -> 'bias){_ : Unit => math.random}
      .project('row, 'vec, 'bias)
    val W0 = A.groupBy('col){_.size('count).reducers(reducers)}
      .map(() -> 'vec){_ : Unit => MatrixUtils.createRealVector((0 until k).map{i => math.random}.toArray)}
      .map(() -> 'bias){_ : Unit => math.random}
      .project('col, 'vec, 'bias)
    (H0, W0)
  }

  /**
   * These functions embed bias terms for both factors into the original factorization.
*/
// W side layout: [w_1 .. w_k, 1.0, bias].
def createWVector(v : RealVector, b : Double) : RealVector =
  v.append(MatrixUtils.createRealVector(Array(1.0, b)))

// H side layout: [h_1 .. h_k, bias, 1.0], so the dot product with a W vector adds both biases.
def createHVector(v : RealVector, b : Double) : RealVector =
  v.append(MatrixUtils.createRealVector(Array(b, 1.0)))

// Inverse of createWVector: (first k entries, bias stored in the last slot).
def explodeWVector(u : RealVector) : (RealVector, Double) = {
  val n = u.getDimension
  val factors = u.getSubVector(0, n - 2)
  (factors, u.getEntry(n - 1))
}

// Inverse of createHVector: (first k entries, bias stored in the second-to-last slot).
def explodeHVector(u : RealVector) : (RealVector, Double) = {
  val n = u.getDimension
  val factors = u.getSubVector(0, n - 2)
  (factors, u.getEntry(n - 2))
}

/*
 * input:
 * A: a sparse matrix in the form ('row, 'col, 'val), with tuples of type (R, C, Double).
 * H: a dense matrix of ('row, 'vec, 'bias)
 * W: a dense matrix of ('col, 'vec, 'bias)
 * With W,H from initGaussian or a previous iteration.
 * Returns (H, W) after one multiplicative update of H followed by one of W against the new H.
 */
def updateGaussian(A : Pipe, H : Pipe, W : Pipe, reducers : Int = 500) : (Pipe, Pipe) = {
  // Row and column vectors are both plain RealVectors with no orientation;
  // which role a vector plays is implied by the pipe it travels in.

  // A non-finite ratio (zero denominator) becomes 1.0, i.e. that coordinate is left unchanged.
  def finiteOrOne(ratio : RealVector) : RealVector =
    MatrixUtils.createRealVector(ratio.toArray.map{c => if(c.isInfinite || c.isNaN) 1.0 else c})

  // ===== Update H =====
  // W'W
  val WW = W
    .mapTo(('vec, 'bias) -> 'WW){wb : (RealVector, Double) =>
      val embedded = createWVector(wb._1, wb._2)
      embedded.outerProduct(embedded)
    }
    .groupAll{_.reduce[RealMatrix]('WW){(lhs, rhs) => lhs.add(rhs)}}

  // W'WH, one vector per row of H
  val WWH = H
    .crossWithTiny(WW)
    .map(('WW, 'vec, 'bias) -> 'vec_wwh){t : (RealMatrix, RealVector, Double) =>
      val (gram, factors, bias) = t
      gram.operate(createHVector(factors, bias))
    }
    .project('row, 'vec_wwh)

  // W'A, accumulated per row of A
  val WA = W
    .joinWithLarger('col -> 'col, A, new InnerJoin(), reducers)
    .map(('val, 'vec, 'bias) -> 'vec){t : (Double, RealVector, Double) =>
      val (value, factors, bias) = t
      createWVector(factors, bias).mapMultiply(value)
    }
    .groupBy('row){
      _.reduce[RealVector]('vec -> 'vec_wa){(lhs, rhs) => lhs.add(rhs)}
        .reducers(reducers)
        .forceToReducers
    }

  // Pointwise multiplier for the old H: (W'A) ./ (W'WH)
  val HM = WA
    .joinWithSmaller('row -> 'row, WWH, new InnerJoin(), reducers)
    .map(('vec_wa, 'vec_wwh) -> 'vec_mult){t : (RealVector, RealVector) => t._1.ebeDivide(t._2)}
    .map('vec_mult -> 'vec_mult){ratio : RealVector => finiteOrOne(ratio)}
    .project('row, 'vec_mult)

  // new H
  val H_ = H
    .joinWithSmaller('row -> 'row, HM, new InnerJoin(), reducers)
    .map(('vec, 'bias, 'vec_mult) -> ('vec, 'bias)){t : (RealVector, Double, RealVector) =>
      val (factors, bias, multiplier) = t
      explodeHVector(createHVector(factors, bias).ebeMultiply(multiplier))
    }
    .project('row, 'vec, 'bias)

  // ===== Update W (against the freshly updated H) =====
  // HH'
  val HH = H_
    .mapTo(('vec, 'bias) -> 'HH){hb : (RealVector, Double) =>
      val embedded = createHVector(hb._1, hb._2)
      embedded.outerProduct(embedded)
    }
    .groupAll{_.reduce[RealMatrix]('HH){(lhs, rhs) => lhs.add(rhs)}}

  // WHH', one vector per column
  val WHH = W
    .crossWithTiny(HH)
    .map(('HH, 'vec, 'bias) -> 'vec_whh){t : (RealMatrix, RealVector, Double) =>
      val (gram, factors, bias) = t
      gram.operate(createWVector(factors, bias))
    }
    .project('col, 'vec_whh)

  // AH', accumulated per column of A
  val AH = H_
    .joinWithLarger('row -> 'row, A, new InnerJoin(), reducers)
    .map(('val, 'vec, 'bias) -> 'vec){t : (Double, RealVector, Double) =>
      val (value, factors, bias) = t
      createHVector(factors, bias).mapMultiply(value)
    }
    .groupBy('col){
      _.reduce[RealVector]('vec -> 'vec_ah){(lhs, rhs) => lhs.add(rhs)}
        .reducers(reducers)
        .forceToReducers
    }

  // Pointwise multiplier for the old W: (AH') ./ (WHH')
  val WM = AH
    .joinWithSmaller('col -> 'col, WHH, new InnerJoin(), reducers)
    .map(('vec_ah, 'vec_whh) -> 'vec_mult){t : (RealVector, RealVector) => t._1.ebeDivide(t._2)}
    .map('vec_mult -> 'vec_mult){ratio : RealVector => finiteOrOne(ratio)}
    .project('col, 'vec_mult)

  // new W
  val W_ = W
    .joinWithSmaller('col -> 'col, WM, new InnerJoin(), reducers)
    .map(('vec, 'bias, 'vec_mult) -> ('vec, 'bias)){t : (RealVector, Double, RealVector) =>
      val (factors, bias, multiplier) = t
      explodeWVector(createWVector(factors, bias).ebeMultiply(multiplier))
    }
    .project('col, 'vec, 'bias)

  (H_, W_)
}

/**
 * Adds v into acc in place (no new vector is allocated) and returns acc.
 */
def addTo(acc : RealVector, v : RealVector) : RealVector = {
  acc.combineToSelf(1.0, 1.0, v)
  acc
}

/**
 * Adds m into acc entry by entry, in place (no new matrix is allocated), and returns acc.
 */
def addTo(acc : RealMatrix, m : RealMatrix) : RealMatrix = {
  for(row <- 0 until acc.getRowDimension; col <- 0 until acc.getColumnDimension) {
    acc.addToEntry(row, col, m.getEntry(row, col))
  }
  acc
}

/*
 * -- weighted version.
 * Optimizes weighted l2 loss, where zeros have weight 1, non zeros have weight 1+alpha.
 * input:
 * A: a sparse matrix in the form ('row, 'col, 'val), with tuples of type (R, C, Double).
* H: a dense matrix of ('row, 'vec, 'bias)
* W: a dense matrix of ('col, 'vec, 'bias)
* alpha: alpha from loss function.
* With W,H from initGaussian or a previous iteration.
* Returns (H, W); H is updated first and the W update uses the new H.
*/
def updateGaussianWeighted(A : Pipe, H : Pipe, W : Pipe, alpha : Double, reducers : Int = 500) : (Pipe, Pipe) = {
  // Multiplicative-update ratio for the weighted loss, shared by both halves so H and W follow the same rule:
  //   ((1 + alpha) * numerator) ./ (gramTimesVec + alpha * sum_j f_j (f_j . vec))
  // numerator: sum over the non-zeros of value * f_j, f_j being the bias-embedded opposite factor.
  // gramTimesVec: (sum over all opposite factors of f f') * vec, vec being the bias-embedded vector updated.
  // nonZeroFactors: the f_j of the non-zeros; non-empty, since it comes from an inner join with A.
  def weightedMultiplier(numerator : RealVector, gramTimesVec : RealVector, vec : RealVector, nonZeroFactors : List[RealVector]) : RealVector = {
    val den_vec = nonZeroFactors.tail.foldLeft(nonZeroFactors.head.mapMultiply(nonZeroFactors.head.dotProduct(vec))){(a, b) => a.combineToSelf(1.0, b.dotProduct(vec), b); a}
    val num = numerator.mapMultiply(1.0 + alpha)
    val den = gramTimesVec.add(den_vec.mapMultiply(alpha))
    num.ebeDivide(den)
  }

  // A non-finite ratio (zero denominator) becomes 1.0, leaving that coordinate unchanged.
  def finiteOrOne(x : RealVector) : RealVector =
    MatrixUtils.createRealVector(x.toArray.map{i => if(i.isInfinite || i.isNaN) 1.0 else i})

  // -- First update H.
  // W'W
  val WW = W.mapTo(('vec, 'bias) -> 'WW){v : (RealVector, Double) =>
      val u = createWVector(v._1, v._2)
      u.outerProduct(u)
    }
    .groupAll{_.reduce[RealMatrix]('WW){(a, b) => a.add(b)}}
  // W'WH; the bias-embedded H vector is kept in 'vec for the weighted denominator.
  val WWH = H.crossWithTiny(WW)
    .map(('vec, 'bias) -> 'vec){x : (RealVector, Double) => createHVector(x._1, x._2)}
    .map(('WW, 'vec) -> 'vec_wwh){x : (RealMatrix, RealVector) => x._1.operate(x._2)}
    .project('row, 'vec_wwh, 'vec)
  // W'A, plus the list of W vectors hit by each row's non-zeros.
  val WA = W.joinWithLarger('col -> 'col, A, new InnerJoin(), reducers)
    .map(('val, 'vec, 'bias) -> ('vec_wa, 'denom_vec)){x : (Double, RealVector, Double) =>
      val v = createWVector(x._2, x._3)
      (v.mapMultiply(x._1), v)
    }
    .groupBy('row){
      _.reduce[RealVector]('vec_wa){(a, b) => addTo(a, b)}
        .toList[RealVector]('denom_vec -> 'denom_vec_list)
        .reducers(reducers)
        //.forceToReducers
    }
    .project('row, 'vec_wa, 'denom_vec_list)
  // Pointwise multiplier to old H
  val HM = WA.joinWithSmaller('row -> 'row, WWH, new InnerJoin(), reducers)
    .map(('vec_wa, 'vec_wwh, 'vec, 'denom_vec_list) -> 'vec_mult){x : (RealVector, RealVector, RealVector, List[RealVector]) =>
      weightedMultiplier(x._1, x._2, x._3, x._4)
    }
    .map('vec_mult -> 'vec_mult){x : RealVector => finiteOrOne(x)}
    .project('row, 'vec_mult)
  // new H.
  val H_ = H.joinWithSmaller('row -> 'row, HM, new InnerJoin(), reducers)
    .map(('vec, 'bias, 'vec_mult) -> ('vec, 'bias)){x : (RealVector, Double, RealVector) => explodeHVector(createHVector(x._1, x._2).ebeMultiply(x._3))}
    .project('row, 'vec, 'bias)

  // -- Then update W.
  // HH'
  val HH = H_.mapTo(('vec, 'bias) -> 'HH){v : (RealVector, Double) =>
      val u = createHVector(v._1, v._2)
      u.outerProduct(u)
    }
    .groupAll{_.reduce[RealMatrix]('HH){(a, b) => a.add(b)}}
  // WHH'; the bias-embedded W vector is kept in 'vec for the weighted denominator.
  val WHH = W.crossWithTiny(HH)
    .map(('vec, 'bias) -> 'vec){x : (RealVector, Double) => createWVector(x._1, x._2)}
    .map(('HH, 'vec) -> 'vec_whh){x : (RealMatrix, RealVector) => x._1.operate(x._2)}
    .project('col, 'vec_whh, 'vec)
  // AH', plus the list of H vectors hit by each column's non-zeros.
  val AH = H_.joinWithLarger('row -> 'row, A, new InnerJoin(), reducers)
    .map(('val, 'vec, 'bias) -> ('vec_ah, 'denom_vec)){x : (Double, RealVector, Double) =>
      val v = createHVector(x._2, x._3)
      (v.mapMultiply(x._1), v)
    }
    .groupBy('col){
      _.reduce[RealVector]('vec_ah){(a, b) => addTo(a, b)}
        .toList[RealVector]('denom_vec -> 'denom_vec_list)
        .reducers(reducers)
        //.forceToReducers
    }
    .project('col, 'vec_ah, 'denom_vec_list)
  // Pointwise multiplier to old W, using the same weighted rule as the H update.
  val WM = AH.joinWithSmaller('col -> 'col, WHH, new InnerJoin(), reducers)
    .map(('vec_ah, 'vec_whh, 'vec, 'denom_vec_list) -> 'vec_mult){x : (RealVector, RealVector, RealVector, List[RealVector]) =>
      weightedMultiplier(x._1, x._2, x._3, x._4)
    }
    .map('vec_mult -> 'vec_mult){x : RealVector => finiteOrOne(x)}
    .project('col, 'vec_mult)
  // new W.
  val W_ = W.joinWithSmaller('col -> 'col, WM, new InnerJoin(), reducers)
    .map(('vec, 'bias, 'vec_mult) -> ('vec, 'bias)){x : (RealVector, Double, RealVector) => explodeWVector(createWVector(x._1, x._2).ebeMultiply(x._3))}
    .project('col, 'vec, 'bias)
  (H_, W_)
}
}
================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/SVD.scala
================================================
package com.etsy.conjecture.scalding

import org.apache.commons.math3.linear._
import cascading.pipe.Pipe
import cascading.pipe.joiner.InnerJoin
import cascading.tuple.Fields
import scala.util.Random

object SVD extends Serializable {
  import com.twitter.scalding.Dsl._

  /**
   * based on http://amath.colorado.edu/faculty/martinss/Pubs/2012_halko_dissertation.pdf
   * page 121.
   *
   * generic parameters:
   * R: the type of the row name variable.
   * C: the type of the column name variable.
   *
   * input:
   * X: a sparse matrix in the form ('row, 'col, 'val), with tuples of type (R, C, Double).
   * d: number of principal components / singular values to compute
   * extra_power: whether to take the second power of XX' in order to improve the approximation quality.
   * reducers: how many reducers to use in the map-reduce stages.
   * no_power: skip the power iterations entirely and use the random projection directly (extra_power is then ignored).
   *
   * output:
   * (U, E, V) with
   * U : pipe of ('row, 'vec) where vec is a RealVector
   * E : pipe of 'eigs which is an Array[Double] of singular values.
   * V : pipe of ('col, 'vec) where vec is a RealVector
   * note that the vectors are rows of the matrices U and V, not the columns which correspond to the left and right singular vectors.
   */
  def apply[R, C](X : Pipe, d : Int, extra_power : Boolean = true, reducers : Int = 500, no_power : Boolean = false) : (Pipe, Pipe, Pipe) = {
    // Sample the columns, into the thin matrix: XS = X * Omega with d+10 Gaussian columns. Each column's row
    // of Omega comes from a Random seeded by that column's hashCode, so every reducer regenerates the same one.
    val XS = X.groupBy('row){_.toList[(C, Double)](('col, 'val) -> 'list).reducers(reducers)}
      .map('list -> 'vec){l : List[(C, Double)] =>
        val a = new Array[Double](d+10)
        l.foreach{i =>
          val r = new Random(i._1.hashCode.toLong)
          (0 until (d+10)).foreach{j =>
            a(j) += r.nextGaussian * i._2
          }
        }
        MatrixUtils.createRealVector(a)
      }
      .project('row, 'vec)
    // Multiply by powers of XX'. This improves the approximation quality.
    val Y = if(!no_power) {
      val XXXS = X
        .joinWithSmaller('row -> 'row_, XS.rename('row -> 'row_), new InnerJoin(), reducers)
        .map(('val, 'vec) -> 'vec){x : (Double, RealVector) => x._2.mapMultiply(x._1)}
        .groupBy('col){_.reduce('vec -> 'vec){(a : RealVector, b : RealVector) => a.add(b)}.forceToReducers.reducers(reducers)}
        .joinWithSmaller('col -> 'col_, X.rename('col -> 'col_), new InnerJoin(), reducers)
        .map(('val, 'vec) -> 'vec){x : (Double, RealVector) => x._2.mapMultiply(x._1)}
        .groupBy('row){_.reduce('vec -> 'vec2){(a : RealVector, b : RealVector) => a.add(b)}.forceToReducers.reducers(reducers)}
      if(extra_power) {
        val XXXXXS = X
          .joinWithSmaller('row -> 'row_, XXXS.rename('row -> 'row_), new InnerJoin(), reducers)
          .map(('val, 'vec2) -> 'vec2){x : (Double, RealVector) => x._2.mapMultiply(x._1)}
          .groupBy('col){_.reduce('vec2 -> 'vec2){(a : RealVector, b : RealVector) => a.add(b)}.forceToReducers.reducers(reducers)}
          .joinWithSmaller('col -> 'col_, X.rename('col -> 'col_), new InnerJoin(), reducers)
          .map(('val, 'vec2) -> 'vec2){x : (Double, RealVector) => x._2.mapMultiply(x._1)}
          .groupBy('row){_.reduce('vec2 -> 'vec2){(a : RealVector, b : RealVector) => a.add(b)}.forceToReducers.reducers(reducers)}
        XS
          .joinWithSmaller('row -> 'row, XXXS, new InnerJoin(), reducers)
          .map(('vec, 'vec2) -> 'vec){x : (RealVector, RealVector) => x._1.append(x._2)}
          .project('row, 'vec)
          .joinWithSmaller('row -> 'row, XXXXXS, new InnerJoin(), reducers)
          .map(('vec, 'vec2) -> 'vec){x : (RealVector, RealVector) => x._1.append(x._2)}
          .project('row, 'vec)
      } else {
        XS
          .joinWithSmaller('row -> 'row, XXXS, new InnerJoin(), reducers)
          .map(('vec, 'vec2) -> 'vec){x : (RealVector, RealVector) => x._1.append(x._2)}
          .project('row, 'vec)
      }
    } else {
      XS
    }
    // What follows is a QR decomposition of Y.
    // Note: Y = QR means Y'Y = R'R so R = chol(Y'Y)
    val YY = Y.mapTo('vec -> 'mat){x : RealVector => x.outerProduct(x)}
      // Could rewrite addition function to act in-place on a or b here.
      .groupAll{_.reduce('mat -> 'mat){(a : RealMatrix, b : RealMatrix) => a.add(b)}}
      .mapTo('mat -> 'mat){m : RealMatrix =>
        val chol = new CholeskyDecomposition(m)
        new LUDecomposition(chol.getL).getSolver.getInverse
      }
    // Determine Q = YR^{-1}
    val Q = Y.crossWithTiny(YY)
      .map(('vec, 'mat) -> 'vec){x : (RealVector, RealMatrix) => x._2.operate(x._1)}
      .project('row, 'vec)
    // B = X'Q
    val B = X.joinWithSmaller('row -> 'row, Q, new InnerJoin(), reducers)
      .map(('val, 'vec) -> 'vec){x : (Double, RealVector) => x._2.mapMultiply(x._1)}
      .groupBy('col){_.reduce('vec -> 'vec){(a : RealVector, b : RealVector) => a.combineToSelf(1, 1, b)}.reducers(reducers).forceToReducers}
    // Use eig(B'B) to get at svd(B) -- B = m * d
    // RWR' = B'B -- R = d*d
    // want UEV = B -- U = m*d, V = d*d
    // so V'E^2V = B'B = RWR' so
    // V = R'
    // W = sqrt(E)
    // U = BE^{-1}V'
    val EB = B.mapTo('vec -> 'mat){x : RealVector => x.outerProduct(x)}
      // Same re: optimizing the addition to not create temp objects.
      .groupAll{_.reduce('mat -> 'mat){(a : RealMatrix, b : RealMatrix) => a.add(b)}}
      .mapTo('mat -> ('eigs, 'eigmat)){m : RealMatrix =>
        val e = new EigenDecomposition(m)
        // NOTE(review): a round-off negative eigenvalue gives NaN here, and a zero one gives infinite entries in V below.
        (e.getRealEigenvalues.map{ei => math.sqrt(ei)}, e.getVT)
      }
    val E = EB.project('eigs).map('eigs -> 'eigs){x : Array[Double] => x.take(d)}
    val U = Q.crossWithTiny(EB.project('eigmat))
      .map(('vec, 'eigmat) -> 'vec){x : (RealVector, RealMatrix) => x._2.operate(x._1).getSubVector(0,d)}
      .project('row, 'vec)
    val V = B.crossWithTiny(EB)
      .map(('vec, 'eigmat, 'eigs) -> 'vec){x : (RealVector, RealMatrix, Array[Double]) =>
        MatrixUtils.createRealDiagonalMatrix(x._3.map{1.0/ _}).operate(x._2.operate(x._1)).getSubVector(0,d)
      }
      .project('col, 'vec)
    (U, E, V)
  }
}
================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/evaluate/GenericCrossValidator.scala
================================================
package com.etsy.conjecture.scalding.evaluate

import com.twitter.scalding._
import cascading.pipe.Pipe
import cascading.tuple.{ Fields, TupleEntry, Tuple }
import com.etsy.conjecture.data._
import com.etsy.conjecture.evaluation._
import com.etsy.conjecture.model._
import com.etsy.conjecture.scalding.train._

class GenericCrossValidator[L <: Label, M <: UpdateableModel[L, M], E <: ModelEvaluation[L]](val evaluator: GenericEvaluator[L], val builder: AbstractModelTrainer[L, M], val folds: Int, val salt: String = "") extends Serializable {
  import Dsl._

  /**
   * Splits the instances into `folds` folds (by a salted hash of the feature vector), trains on all but one fold
   * and predicts the held-out one, for every fold.
   * Returns (eval, preds_all): eval is a single-row pipe whose '__eval field is the aggregated evaluation
   * rendered as a String; preds_all holds (labelField, predictionField) for every held-out instance.
   */
  def crossValidateWithPredictions(pipe: Pipe, instanceField: Symbol, predictionField: Symbol, labelField: Symbol = '__actual): (Pipe, Pipe) = {
    val folded = pipe.map(instanceField -> '__fold) { li: LabeledInstance[L] =>
        val r = (li.getVector.hashCode.toString + salt).hashCode % folds
        // hashCode may be negative and % keeps the sign of the dividend; shift into [0, folds) so every
        // instance lands in one of the folds iterated below (otherwise it is never used as test data).
        if (r < 0) r + folds else r
      }
      .forceToDisk
    val preds = (0 until folds).map { i: Int => predictFold(folded, '__model, instanceField, labelField, predictionField, i) }
    val eval = preds.map { i: Pipe => evaluator.evaluate(i, predictionField, labelField, '__eval) }.reduce { _ ++ _ }
      .groupAll {
        _.foldLeft('__eval -> '__eval)(new EvaluationAggregator[L]()) { (a: EvaluationAggregator[L], e: E) => a.add(e); a }
      }
      .map('__eval -> '__eval) { a: EvaluationAggregator[L] => a.toString }
    val preds_all = preds.reduce { _ ++ _ }.project(labelField, predictionField)
    (eval, preds_all)
  }

  // note that the models on each fold may be calibrated differently, which will mess up the AUC calculation.
  // this may not be a big problem though.
  def predictFold(folded: Pipe, modelField: Symbol, instanceField: Symbol, labelField: Symbol, predictionField: Symbol, fold: Int): Pipe = {
    val train_inst = folded.filter('__fold) { x: Int => x != fold }
    val test_inst = folded.filter('__fold) { x: Int => x == fold }
    val model = builder.train(train_inst, instanceField, modelField)
    evaluator.assign_predictions(test_inst, instanceField, labelField, model, modelField, predictionField)
  }

  def crossValidate(pipe: Pipe, instanceField: Symbol): Pipe = {
    crossValidateWithPredictions(pipe, instanceField, '__prediction, '__actual)._1
  }

  def evaluateFold(folded: Pipe, modelField: Symbol, instanceField: Symbol, labelField: Symbol, evalField: Symbol, fold: Int): Pipe = {
    val train_inst = folded.filter('__fold) { x: Int => x != fold }
    val test_inst = folded.filter('__fold) { x: Int => x == fold }
    val model = builder.train(train_inst, instanceField, modelField)
    evaluator.evaluate(test_inst, instanceField, labelField, model, modelField, evalField)
  }
}

class BinaryCrossValidator(args: Args, folds: Int) extends GenericCrossValidator[BinaryLabel, UpdateableLinearModel[BinaryLabel], BinaryModelEvaluation](
  new BinaryEvaluator(), new BinaryModelTrainer(args), folds, args.getOrElse("salt", ""))

class RegressionCrossValidator(args: Args, folds: Int) extends GenericCrossValidator[RealValuedLabel, UpdateableLinearModel[RealValuedLabel], RegressionModelEvaluation](
  new RegressionEvaluator(), new RegressionModelTrainer(args), folds, args.getOrElse("salt", ""))

class MulticlassCrossValidator(args: Args, folds: Int,
  categories: Array[String]) extends GenericCrossValidator[MulticlassLabel, UpdateableMulticlassLinearModel, MulticlassModelEvaluation](
  new MulticlassEvaluator(categories), new MulticlassModelTrainer(args, categories), folds, args.getOrElse("salt", ""))
================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/evaluate/GenericEvaluator.scala
================================================
package com.etsy.conjecture.scalding.evaluate

import com.twitter.scalding._
import com.etsy.conjecture._
import com.etsy.conjecture.data._
import com.etsy.conjecture.evaluation._
import com.etsy.conjecture.model._
import cascading.pipe.Pipe

abstract class GenericEvaluator[L <: Label] extends Serializable {
  import Dsl._

  def build(): ModelEvaluation[L]

  /**
   * Aggregates one ModelEvaluation over every row: each row contributes eval.add(label, prediction) and the
   * partial evaluations are merged. Returns a single-row pipe containing evaluation_field.
   */
  def evaluate(instance_pipe: Pipe, predict_field: Symbol, label_field: Symbol, evaluation_field: Symbol): Pipe = {
    val partialEval = '__partial_eval
    instance_pipe
      .map((label_field, predict_field) -> partialEval){ pair : (L, L) =>
        val eval = build
        eval.add(pair._1, pair._2)
        eval
      }
      .groupAll{
        _.reduce(partialEval -> evaluation_field){ (eval : ModelEvaluation[L], final_eval : ModelEvaluation[L]) =>
          final_eval.merge(eval)
          final_eval
        }
      }
      .project(evaluation_field)
  }

  /**
   * Predicts every instance with the model in model_pipe, then evaluates those predictions against the labels.
   */
  def evaluate(instance_pipe: Pipe, instance_field: Symbol, label_field: Symbol, model_pipe: Pipe, model_field: Symbol, evaluation_field: Symbol): Pipe = {
    val instances_with_predictions = assign_predictions(instance_pipe, instance_field, label_field, model_pipe, model_field, 'prediction)
    // The four-argument overload takes the prediction field first, then the label field.
    evaluate(instances_with_predictions, 'prediction, label_field, evaluation_field)
  }

  /**
   * Crosses every instance with model_pipe (presumably a single model row) and emits
   * (label_field, prediction_field) = (the instance's label, the model's prediction for its vector).
   */
  def assign_predictions(instance_pipe: Pipe, instance_field: Symbol, label_field: Symbol, model_pipe: Pipe, model_field: Symbol, prediction_field: Symbol = 'prediction) = {
    instance_pipe.crossWithTiny(model_pipe)
      .map((instance_field, model_field) -> (label_field, prediction_field)) { x: (LabeledInstance[L], Model[L]) => (x._1.getLabel, x._2.predict(x._1.getVector)) }
      .project(label_field, prediction_field)
  }
}

class BinaryEvaluator extends GenericEvaluator[BinaryLabel] {
  def build() = new BinaryModelEvaluation()
}

class MulticlassEvaluator(categories: Array[String]) extends GenericEvaluator[MulticlassLabel] {
  def build() = new MulticlassModelEvaluation(categories)
}

class RegressionEvaluator extends GenericEvaluator[RealValuedLabel] {
  def build() = new RegressionModelEvaluation()
}
================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/factorize/FactorizationTools.scala
================================================
package com.etsy.conjecture.scalding.factorize

import cascading.pipe.Pipe
import org.apache.commons.math3.linear._
import cascading.pipe.joiner.InnerJoin

object FactorizationTools {

  /**
   * Binary-valued variant of approxLeftFactorsLeastSquares: every (left_id, right_id) row of designMatrix
   * is treated as having value 1.0.
   */
  def approxLeftFactorsLeastSquaresBinary(rightFactors : Pipe, id_sym : Symbol, right_vec_sym : Symbol, designMatrix : Pipe, left_id : Symbol, right_id : Symbol, left_vec_symbol : Symbol, spill_threshold : Int = 1000000, parallelism : Int = 1000) : Pipe = {
    import com.twitter.scalding.Dsl._
    approxLeftFactorsLeastSquares(rightFactors, id_sym, right_vec_sym, designMatrix.insert('value, 1.0), left_id, right_id, 'value, left_vec_symbol, spill_threshold, parallelism)
  }

  /**
   * For every left id, computes (R'R)^{-1} * sum_j value_j * r_j over the right factors r_j it co-occurs
   * with in designMatrix, i.e. the least-squares left factor given fixed right factors R.
   * Returns (left_id, left_vec_symbol) pairs.
   */
  def approxLeftFactorsLeastSquares(rightFactors : Pipe, id_sym : Symbol, right_vec_sym : Symbol, designMatrix : Pipe, left_id : Symbol, right_id : Symbol, value_sym : Symbol, left_vec_symbol : Symbol, spill_threshold : Int = 1000000, parallelism : Int = 1000) : Pipe = {
    import com.twitter.scalding.Dsl._
    val inv_sym = 'inverse
    val inv_self_outer = rightFactors
      .mapTo(right_vec_sym -> right_vec_sym) { l : RealVector => l.outerProduct(l) }
      .groupAll{ _.reduce[RealMatrix](right_vec_sym){ (x, y) => x.add(y) } }
      .mapTo(right_vec_sym -> inv_sym) { ll : RealMatrix => new LUDecomposition(ll).getSolver.getInverse }
    val premultiplied_right_factors = rightFactors
      .crossWithTiny(inv_self_outer)
      .map((right_vec_sym, inv_sym) -> right_vec_sym) { x:(RealVector, RealMatrix) => x._2.operate(x._1) }
      .project(id_sym, right_vec_sym)
    designMatrix.joinWithSmaller(right_id -> id_sym, premultiplied_right_factors, new InnerJoin(), parallelism)
      // Save an alloc if we do the binary case.
      .map((right_vec_sym, value_sym) -> right_vec_sym) { x : (RealVector, Double) => if(x._2 == 1.0) x._1 else x._1.mapMultiply(x._2) }
      .groupBy(left_id) {
        _.reduce[RealVector](right_vec_sym -> left_vec_symbol){ (x, y) => x.combineToSelf(1, 1, y) }
          .reducers(parallelism)
          .spillThreshold(spill_threshold)
      }
      .project(left_id, left_vec_symbol)
  }
}
================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/train/AbstractModelTrainer.scala
================================================
package com.etsy.conjecture.scalding.train

import cascading.pipe.Pipe
import com.etsy.conjecture.data._
import com.etsy.conjecture.model._
import com.twitter.scalding._

trait AbstractModelTrainer[L <: Label, M <: UpdateableModel[L, M]] extends Serializable {
  def train(instances: Pipe, instanceField: Symbol, modelField: Symbol): Pipe
  def reTrain(instances: Pipe, instanceField: Symbol, model: Pipe, modelField: Symbol): Pipe
}
================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/train/BinaryModelTrainer.scala
================================================
package com.etsy.conjecture.scalding.train

import cascading.pipe.Pipe
import com.twitter.scalding._
import com.etsy.conjecture.data._
import com.etsy.conjecture.model._
import java.io.File
import scala.io.Source

class BinaryModelTrainer(args: Args) extends AbstractModelTrainer[BinaryLabel, UpdateableLinearModel[BinaryLabel]] with ModelTrainerStrategy[BinaryLabel, UpdateableLinearModel[BinaryLabel]] {

  /**
   * Number of iterations for sequential gradient descent
   */
  var iters = args.getOrElse("iters", "1").toInt

  override def getIters: Int = iters

  /**
   * What type of linear model should be used?
   * Options are:
   * 1. perceptron
   * 2.
linear_svm * 3. logistic_regression * 4. mira */ val modelType = args.getOrElse("model", "logistic_regression") /** * What kind of learning rate schedule / regularization * should we use? * * Options: * 1. elastic_net * 2. adagrad * 3. passive_aggressive * 4. ftrl */ val optimizerType = args.getOrElse("optimizer", "elastic_net") /** Aggressiveness parameter for passive aggressive classifier **/ val aggressiveness = args.getOrElse("aggressiveness", "2.0").toDouble val finalThresholding = args.getOrElse("final_thresholding", "0.0").toDouble /** * Initial learning rate used for SGD learning. */ val initialLearningRate = args.getOrElse("rate", "0.1").toDouble /** Base of the exponential learning rate (e.g., 0.99^{# examples seen}). **/ val exponentialLearningRateBase = args.getOrElse("exponential_learning_rate_base", "1.0").toDouble /** Whether to use the exponential learning rate. If not chosen then the learning rate is like 1.0 / epoch. **/ val useExponentialLearningRate = args.boolean("exponential_learning_rate_base") /** * A fudge factor so that an "epoch" for the purpose of learning rate computation can be more than one example, * in which case the "epoch" will take a fractional amount equal to {# examples seen} / examples_per_epoch. */ val examplesPerEpoch = args.getOrElse("examples_per_epoch", "10000").toDouble /** How to subsample each class, in the case of imbalanced data. 
**/ val zeroClassProb = args.getOrElse("zero_class_prob", "1.0").toDouble val oneClassProb = args.getOrElse("one_class_prob", "1.0").toDouble /** * Weight on laplace regularization- a laplace prior on the parameters * sparsity inducing ala lasso */ val laplace = args.getOrElse("laplace", "0.0").toDouble /** * Weight on gaussian prior on the parameters * similar to ridge */ val gauss = args.getOrElse("gauss", "0.0").toDouble /** * Learning rate parameters for FTRL */ val ftrlAlpha = args.getOrElse("ftrlAlpha", "1.0").toDouble val ftrlBeta = args.getOrElse("ftrlBeta", "1.0").toDouble /** * Choose an optimizer to use */ val o = optimizerType match { case "elastic_net" => new ElasticNetOptimizer() case "adagrad" => new AdagradOptimizer() case "passive_aggressive" => new PassiveAggressiveOptimizer().setC(aggressiveness).isHinge(true) case "ftrl" => new FTRLOptimizer().setAlpha(ftrlAlpha).setBeta(ftrlBeta) case "control" => new ControlOptimizer() case "mira" => new MIRAOptimizer() } val optimizer = o.setGaussianRegularizationWeight(gauss) .setLaplaceRegularizationWeight(laplace) .setExamplesPerEpoch(examplesPerEpoch) .setUseExponentialLearningRate(useExponentialLearningRate) .setExponentialLearningRateBase(exponentialLearningRateBase) .setInitialLearningRate(initialLearningRate) /** Period of gradient truncation updates **/ val truncationPeriod = args.getOrElse("period", Int.MaxValue.toString).toInt /** * Aggressiveness of gradient truncation updates, how much shrinkage * is applied to the model's parameters */ val truncationAlpha = args.getOrElse("alpha", "0.0").toDouble /** * Threshold for applying gradient truncation updates * parameter values smaller than this in magnitude are truncated */ val truncationThresh = args.getOrElse("thresh", "0.0").toDouble /** Size of minibatch for mini-batch training, defaults to 1 which is just SGD. 
**/ val batchsz = args.getOrElse("mini_batch_size", "1").toInt override def miniBatchSize: Int = batchsz override def sampleProb(l: BinaryLabel): Double = { if (l.getValue < 0.5) zeroClassProb else oneClassProb } override def modelPostProcess(m: UpdateableLinearModel[BinaryLabel]): UpdateableLinearModel[BinaryLabel] = { m.thresholdParameters(finalThresholding) m.setArgString(args.toString) m.teardown() m } if(modelType == "mira" && optimizerType != "mira"){ throw new IllegalArgumentException("MIRA only uses a MIRAOptimizer"); } def getModel: UpdateableLinearModel[BinaryLabel] = { val model = modelType match { case "perceptron" => new Hinge(optimizer).setThreshold(0.0) case "linear_svm" => new Hinge(optimizer).setThreshold(1.0) case "logistic_regression" => new LogisticRegression(optimizer) case "mira" => new MIRA() } model.setTruncationPeriod(truncationPeriod) .setTruncationThreshold(truncationThresh) .setTruncationUpdate(truncationAlpha) model } val bins = args.getOrElse("bins", "100").toInt val trainer = if (args.boolean("large")) new LargeModelTrainer(this, bins) else new SmallModelTrainer(this) def train(instances: Pipe, instanceField: Symbol = 'instance, modelField: Symbol = 'model): Pipe = { trainer.train(instances, instanceField, modelField) } def reTrain(instances: Pipe, instanceField: Symbol, model: Pipe, modelField: Symbol): Pipe = { trainer.reTrain(instances, instanceField, model, modelField) } } ================================================ FILE: src/main/scala/com/etsy/conjecture/scalding/train/ClusteringModelTrainer.scala ================================================ package com.etsy.conjecture.scalding.train import cascading.pipe.Pipe import com.twitter.scalding._ import com.etsy.conjecture.data._ import com.etsy.conjecture.model._ import scala.collection.JavaConversions._ import java.io.File import scala.io.Source class ClusteringModelTrainer(args: Args, centers: Map[String, StringKeyedVector]) extends AbstractModelTrainer[ClusterLabel, 
ClusteringModel[ClusterLabel]] with ModelTrainerStrategy[ClusterLabel, ClusteringModel[ClusterLabel]] { // number of iterations for // sequential gradient descent var iters = args.getOrElse("iters", "1").toInt /* * Number of clusters to build */ val num_clusters = args.getOrElse("num_clusters","100").toInt /* * Error tolerance for the l1 projection in 'web scale kmeans' */ val error_tolerance = args.getOrElse("error_tolerance","0.01").toDouble /* * Ball radius for the l1 projection in 'web scale kmeans' */ val ball_radius = args.getOrElse("ball_radius","1.0").toDouble override def getIters: Int = iters def getModel: ClusteringModel[ClusterLabel] = { new KMeans(centers) .setNumClusters(num_clusters) .setL1ProjectionErrorTolerance(error_tolerance) .setL1ProjectionBallRadius(ball_radius) } val bins = args.getOrElse("bins", "100").toInt val trainer = if (args.boolean("large")) new LargeModelTrainer(this, bins) else new SmallModelTrainer(this) def train(instances: Pipe, instanceField: Symbol = 'instance, modelField: Symbol = 'model): Pipe = { trainer.train(instances, instanceField, modelField) } def reTrain(instances: Pipe, instanceField: Symbol, model: Pipe, modelField: Symbol): Pipe = { trainer.reTrain(instances, instanceField, model, modelField) } } ================================================ FILE: src/main/scala/com/etsy/conjecture/scalding/train/LargeModelTrainer.scala ================================================ package com.etsy.conjecture.scalding.train import cascading.flow._ import cascading.operation._ import cascading.pipe._ import cascading.pipe.joiner.InnerJoin import com.etsy.conjecture.data._ import com.etsy.conjecture.model._ import com.twitter.scalding._ class LargeModelTrainer[L <: Label, M <: UpdateableModel[L, M]](strategy: ModelTrainerStrategy[L, M], training_bins: Int) extends AbstractModelTrainer[L, M] { import Dsl._ def train(instances: Pipe, instanceField: Symbol = 'instance, modelField: Symbol = 'model): Pipe = { trainRecursively(None, 
modelField, binTrainingData(instances, instanceField), instanceField, strategy.getIters) } def reTrain(instances: Pipe, instanceField: Symbol, model: Pipe, modelField: Symbol): Pipe = { throw new UnsupportedOperationException("not implemented due to expensiveness of model duplication") } def binTrainingData(instances: Pipe, instanceField: Symbol): Pipe = { instances .project(instanceField) .map(instanceField -> 'bin) { b: LabeledInstance[L] => b.hashCode % training_bins } } // This implements a full iteration of training, ending with a pipe with a model. protected def trainIteration(modelPipe: Option[Pipe], modelField: Symbol, instancePipe: Pipe, instanceField: Symbol): Pipe = { val iterationField = '__iteration__ val modelCountField = '__model_count__ // Subsample instances. val subsampled = instancePipe.filter(instanceField) { i: LabeledInstance[L] => math.random < strategy.sampleProb(i.getLabel) } // Get models on each mapper. (modelPipe match { case Some(pipe) => subsampled.joinWithSmaller('bin -> 'bin, pipe, new InnerJoin(), training_bins) case _ => subsampled.map(instanceField -> (instanceField, modelField)) { x: LabeledInstance[L] => (x, strategy.getModel) } }) // Count iteration numbers. .insert(iterationField, 0) .insert(modelCountField, 1) // Convert instances to instance list. .map(instanceField -> instanceField) { i: LabeledInstance[L] => List(i) } // Perform map-side aggregation of models, which are then sent to a single reduce node for merging. .groupBy('bin) { _.reduce[(M, List[LabeledInstance[L]], Int, Int)]( (modelField, instanceField, iterationField, modelCountField) -> (modelField, instanceField, iterationField, modelCountField))(strategy.modelReduceFunction) .reducers(training_bins) } .mapTo((modelField, iterationField) -> modelField) { x: (M, Int) => strategy.endIteration(x._1, x._2, training_bins) } // flatten submodels and aggregate on different reducers. 
.flatMapTo(modelField -> ('param, 'value)) { m: M => println("epoch: " + m.getEpoch) m.setParameter("__epoch__", m.getEpoch) new Iterable[(String, Double)]() { def iterator() = { new Iterator[(String, Double)]() { val it = m.decompose def hasNext: Boolean = { it.hasNext } def next: (String, Double) = { val e = it.next(); (e.getKey, e.getValue) } } } } } .groupBy('param) { _.sum[Double]('value).forceToReducers } // Duplicate the summed parameters rather than duplicating the reconstructed model, for speed reasons. .flatMapTo(('param, 'value) -> ('bin, 'param, 'value)) { b: (String, Double) => (0 until training_bins).map { i => (i, b._1, b._2) } } // Reconstruct the model for each bin. Uses a hacked on Scalding operator due to kryo serialization not supporting copy(). .groupBy('bin) { _.every { pipe => new Every( pipe, ('param, 'value), new FoldAggregator[(String, Double), M]( { (model: M, param: (String, Double)) => if (param._1 == "__epoch__") { val epoch = (param._2 / training_bins).toLong println("epoch: " + epoch) model.setEpoch(epoch) } else { model.setParameter(param._1, param._2) } model }, strategy.getModel, modelField, implicitly[TupleConverter[(String, Double)]], implicitly[TupleSetter[M]]) { override def start(flowProcess: FlowProcess[_], call: AggregatorCall[M]) = call.setContext(strategy.getModel) }) } .reducers(training_bins) } .project('bin, modelField) } protected def trainRecursively(modelPipe: Option[Pipe], modelField: Symbol, instancePipe: Pipe, instanceField: Symbol, iterations: Int): Pipe = { val updatedPipe = trainIteration(modelPipe, modelField, instancePipe, instanceField) if (iterations == 1) { updatedPipe.filter('bin) { b: Int => b == 0 }.mapTo(modelField -> modelField) { strategy.modelPostProcess }.groupAll { _.pass } } else { trainRecursively(Some(updatedPipe), modelField, instancePipe, instanceField, iterations - 1) } } } ================================================ FILE: 
src/main/scala/com/etsy/conjecture/scalding/train/ModelTrainerStrategy.scala ================================================ package com.etsy.conjecture.scalding.train import cascading.pipe.Pipe import com.etsy.conjecture.data._ import com.etsy.conjecture.model._ import com.twitter.scalding._ trait ModelTrainerStrategy[L <: Label, M <: UpdateableModel[L, M]] extends Serializable { import Dsl._ // How many iterations of training to perform. def getIters: Int = 1 // The subclass just implements a thing that creates an initial model, from a set of // initial parameters. def getModel: M // Optionally subsample depending on the lable. def sampleProb(label: L): Double = 1.0 // Optionally perform some post-processing on the model after training. def modelPostProcess(m: M): M = m def miniBatchSize: Int = 1 // Function to merge two sub-models. // Can be overriden to change default behavior. def mergeModels(model1: M, model2: M, iteration1: Int, iteration2: Int): M = { model1.merge(model2, 1.0) model1 } // Do something at the end of the iteration. def endIteration(model: M, iteration: Int, models: Int): M = { model.reScale(1.0 / models) model } // Train the model on a mini batch. // Returns the trained model, remining instances from the list, and the iteration number. def updateModelOnMiniBatch(model: M, instances: List[LabeledInstance[L]], start_iteration: Int): (M, List[LabeledInstance[L]], Int) = { val batch_sz = miniBatchSize // might be something that needs computing who knows. val mini_batch = new java.util.ArrayList[LabeledInstance[L]](miniBatchSize) val n_batches = instances.size / batch_sz var batch = 0 val iterator = instances.iterator while (batch < n_batches) { // one extra iteration to get the remainder into the array. 
(0 until batch_sz).foreach { i => mini_batch.add(i, iterator.next) } model.update(mini_batch) batch += 1 } val remainder = iterator.toList (model, remainder, start_iteration + n_batches) } // This implements the associative operation of the model training. // Note that each tuple contains an instance not yet used for training. def modelReduceFunction(a: (M, List[LabeledInstance[L]], Int, Int), b: (M, List[LabeledInstance[L]], Int, Int)): (M, List[LabeledInstance[L]], Int, Int) = { if (a._3 > 0 && b._3 > 0) { // Both models have some prior training. val ua = updateModelOnMiniBatch(a._1, a._2, a._3) val ub = updateModelOnMiniBatch(b._1, b._2, b._3) // Merge together (mergeModels(ua._1, ub._1, ua._3, ub._3), ua._2 ++ ub._2, ua._3 + ub._3, a._4 + b._4) } else if (b._3 > 0) { // Only model b has some prior training. // Update model b using all instances. val uba = updateModelOnMiniBatch(b._1, a._2 ++ b._2, b._3) (uba._1, uba._2, uba._3, b._4) } else { // Either no model is trained, or only a is. // Update model a using whatever intances are available. val uab = updateModelOnMiniBatch(a._1, a._2 ++ b._2, a._3) (uab._1, uab._2, uab._3, a._4) } } } ================================================ FILE: src/main/scala/com/etsy/conjecture/scalding/train/MulticlassModelTrainer.scala ================================================ package com.etsy.conjecture.scalding.train import cascading.pipe.Pipe import com.twitter.scalding._ import com.etsy.conjecture.data._ import com.etsy.conjecture.model._ import scala.io.Source import scala.collection.JavaConversions._ class MulticlassModelTrainer(args: Args, categories: Array[String]) extends AbstractModelTrainer[MulticlassLabel, UpdateableMulticlassLinearModel] with ModelTrainerStrategy[MulticlassLabel, UpdateableMulticlassLinearModel] { /** * Number of iterations for sequential gradient descent */ val iters = args.getOrElse("iters", "1").toInt /** * What type of linear model should be used? * Options are: * 1. perceptron * 2. 
linear_svm * 3. logistic_regression * 4. mira */ val modelType = args.getOrElse("model", "logistic_regression").toString /** * What kind of learning rate schedule / regularization * should we use? * * Options: * 1. elastic_net * 2. adagrad * 3. passive_aggressive * 4. ftrl */ val optimizerType = args.getOrElse("optimizer", "elastic_net") /** Aggressiveness parameter for passive aggressive classifier **/ val aggressiveness = args.getOrElse("aggressiveness", "2.0").toDouble val finalThresholding = args.getOrElse("final_thresholding", "0.0").toDouble /** * Initial learning rate used for SGD learning. */ val initialLearningRate = args.getOrElse("rate", "0.1").toDouble /** Base of the exponential learning rate (e.g., 0.99^{# examples seen}). **/ val exponentialLearningRateBase = args.getOrElse("exponential_learning_rate_base", "1.0").toDouble /** Whether to use the exponential learning rate. If not chosen then the learning rate is like 1.0 / epoch. **/ val useExponentialLearningRate = args.boolean("exponential_learning_rate_base") /** * A fudge factor so that an "epoch" for the purpose of learning rate computation can be more than one example, * in which case the "epoch" will take a fractional amount equal to {# examples seen} / examples_per_epoch. 
*/ val examplesPerEpoch = args.getOrElse("examples_per_epoch", "10000").toDouble /** * Weight on laplace regularization- a laplace prior on the parameters * sparsity inducing ala lasso */ val laplace = args.getOrElse("laplace", "0.0").toDouble /** * Weight on gaussian prior on the parameters * similar to ridge */ val gauss = args.getOrElse("gauss", "0.0").toDouble /** Period of gradient truncation updates **/ val truncationPeriod = args.getOrElse("period", Int.MaxValue.toString).toInt /** * Aggressiveness of gradient truncation updates, how much shrinkage * is applied to the model's parameters */ val truncationAlpha = args.getOrElse("alpha", "0.0").toDouble /** * Threshold for applying gradient truncation updates * parameter values smaller than this in magnitude are truncated */ val truncationThresh = args.getOrElse("thresh", "0.0").toDouble /** * Learning rate parameters for FTRL */ val ftrlAlpha = args.getOrElse("ftrlAlpha", "1.0").toDouble val ftrlBeta = args.getOrElse("ftrlBeta", "1.0").toDouble val classSampleProbabilities = args.optional("class_probs") .map { entries : String => entries.split(",").map { s:String => val p = s.split(":") (p(0), p(1).toDouble) }.toMap } .getOrElse(Map[String, Double]()) val classSampleProbabilityFile = args.optional("class_prob_file") // stores sampling rates for different classes lazy val probabilityMap : Map[String, Double] = { val probs = categories.map{ c:String => (c, classSampleProbabilities.getOrElse(c, 1.0)) }.toMap classSampleProbabilityFile match { case Some(f) => probs ++ Source.fromFile(f).getLines().map{ s:String => val p = s.split(":") (p(0), p(1).toDouble) }.toMap case None => probs } } override def getIters: Int = iters override def sampleProb(l : MulticlassLabel) : Double = { probabilityMap.getOrElse(l.getLabel(), 1.0) } override def modelPostProcess(m: UpdateableMulticlassLinearModel) : UpdateableMulticlassLinearModel = { m.thresholdParameters(finalThresholding) m.setArgString(args.toString) m.teardown() m } 
/** * Choose an optimizer to use */ val o = optimizerType match { case "elastic_net" => new ElasticNetOptimizer() case "adagrad" => new AdagradOptimizer() case "passive_aggressive" => new PassiveAggressiveOptimizer().setC(aggressiveness).isHinge(true) case "ftrl" => new FTRLOptimizer().setAlpha(ftrlAlpha).setBeta(ftrlBeta) case "mira" => new MIRAOptimizer() } val optimizer = o.setGaussianRegularizationWeight(gauss) .setLaplaceRegularizationWeight(laplace) .setExamplesPerEpoch(examplesPerEpoch) .setUseExponentialLearningRate(useExponentialLearningRate) .setExponentialLearningRateBase(exponentialLearningRateBase) .setInitialLearningRate(initialLearningRate) def buildMultiClassModel(buildSubModel : () => UpdateableLinearModel[BinaryLabel], categories : Array[String]) : UpdateableMulticlassLinearModel = { val param = categories.map{ i : String => (i, buildSubModel().setTruncationPeriod(truncationPeriod) .setTruncationThreshold(truncationThresh) .setTruncationUpdate(truncationAlpha)) }.toMap new UpdateableMulticlassLinearModel(new java.util.HashMap[String,UpdateableLinearModel[BinaryLabel]](param) ) } if(modelType == "mira" && optimizerType != "mira"){ throw new IllegalArgumentException("MIRA only uses a MIRAOptimizer"); } def getModel: UpdateableMulticlassLinearModel = { val model = modelType match { case "perceptron" => buildMultiClassModel({() => new Hinge(optimizer).setThreshold(0.0)}, categories) case "linear_svm" => buildMultiClassModel({() => new Hinge(optimizer).setThreshold(1.0)}, categories) // TODO: re-make proper multiclass logistic regression instead of this one vs all thing. case "logistic_regression" => buildMultiClassModel({() => new LogisticRegression(optimizer)}, categories) // TODO: re-make multiclass mira. 
case "mira" => buildMultiClassModel({() => new MIRA()}, categories) } model.setModelType(modelType) model } val bins = args.getOrElse("bins", "100").toInt val trainer = if (args.boolean("large")) new LargeModelTrainer(this, bins) else new SmallModelTrainer(this) /** Size of minibatch for mini-batch training, defaults to 1 which is just SGD. **/ val batchsz = args.getOrElse("mini_batch_size", "1").toInt override def miniBatchSize: Int = batchsz def train(instances: Pipe, instanceField: Symbol = 'instance, modelField: Symbol = 'model): Pipe = { trainer.train(instances, instanceField, modelField) } def reTrain(instances: Pipe, instanceField: Symbol, model: Pipe, modelField: Symbol): Pipe = { trainer.reTrain(instances, instanceField, model, modelField) } } ================================================ FILE: src/main/scala/com/etsy/conjecture/scalding/train/RegressionModelTrainer.scala ================================================ package com.etsy.conjecture.scalding.train import cascading.pipe.Pipe import com.twitter.scalding._ import com.etsy.conjecture.data._ import com.etsy.conjecture.model._ class RegressionModelTrainer(args: Args) extends AbstractModelTrainer[RealValuedLabel, UpdateableLinearModel[RealValuedLabel]] with ModelTrainerStrategy[RealValuedLabel, UpdateableLinearModel[RealValuedLabel]] { // number of iterations for // sequential gradient descent val iters = args.getOrElse("iters", "1").toInt override def getIters: Int = iters // weight on laplace regularization- a laplace prior on the parameters // sparsity inducing ala lasso val laplace = args.getOrElse("laplace", "0.5").toDouble // weight on gaussian prior on the parameters // similar to ridge val gauss = args.getOrElse("gauss", "0.5").toDouble val modelType = "least_squares" // just one model type for regression at the moment /** * What kind of learning rate schedule / regularization * should we use? * * Options: * 1. elastic_net * 2. adagrad * 3. passive_aggressive * 4. 
ftrl */ val optimizerType = args.getOrElse("optimizer", "elastic_net") // aggressiveness parameter for passive aggressive classifier val aggressiveness = args.getOrElse("aggressiveness", "2.0").toDouble val ftrlAlpha = args.getOrElse("ftrlAlpha", "1.0").toDouble val ftrlBeta = args.getOrElse("ftrlBeta", "1.0").toDouble // initial learning rate used for SGD learning. this decays according to the // inverse of the epoch val initialLearningRate = args.getOrElse("rate", "0.1").toDouble // Base of the exponential learning rate (e.g., 0.99^{# examples seen}). val exponentialLearningRateBase = args.getOrElse("exponential_learning_rate_base", "1.0").toDouble // Whether to use the exponential learning rate. If not chosen then the learning rate is like 1.0 / epoch. val useExponentialLearningRate = args.boolean("exponential_learning_rate_base") // A fudge factor so that an "epoch" for the purpose of learning rate computation can be more than one example, // in which case the "epoch" will take a fractional amount equal to {# examples seen} / examples_per_epoch. 
val examplesPerEpoch = args.getOrElse("examples_per_epoch", "10000").toDouble

  /**
   * Choose an optimizer to use
   */
  val o = optimizerType match {
    case "elastic_net" => new ElasticNetOptimizer()
    case "adagrad" => new AdagradOptimizer()
    case "passive_aggressive" => new PassiveAggressiveOptimizer().setC(aggressiveness).isHinge(false)
    case "ftrl" => new FTRLOptimizer().setAlpha(ftrlAlpha).setBeta(ftrlBeta)
  }
  val optimizer = o.setExamplesPerEpoch(examplesPerEpoch)
    .setUseExponentialLearningRate(useExponentialLearningRate)
    .setExponentialLearningRateBase(exponentialLearningRateBase)
    .setInitialLearningRate(initialLearningRate)

  def getModel: UpdateableLinearModel[RealValuedLabel] = {
    val model = modelType match {
      case "least_squares" => new LeastSquaresRegressionModel(optimizer)
    }
    model
  }

  val bins = args.getOrElse("bins", "100").toInt
  val trainer = if (args.boolean("large")) new LargeModelTrainer(this, bins) else new SmallModelTrainer(this)

  def train(instances: Pipe, instanceField: Symbol = 'instance, modelField: Symbol = 'model): Pipe = {
    trainer.train(instances, instanceField, modelField)
  }

  def reTrain(instances: Pipe, instanceField: Symbol, model: Pipe, modelField: Symbol): Pipe = {
    trainer.reTrain(instances, instanceField, model, modelField)
  }
}

================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/train/SmallModelTrainer.scala
================================================
package com.etsy.conjecture.scalding.train

import cascading.pipe.Pipe
import com.etsy.conjecture.data._
import com.etsy.conjecture.model._
import com.twitter.scalding._

class SmallModelTrainer[L <: Label, M <: UpdateableModel[L, M]](strategy: ModelTrainerStrategy[L, M]) extends AbstractModelTrainer[L, M] {
  import Dsl._

  // Functionality to train a small model (hundreds of thousands of features, arbitrarily many instances)
  // Trains a model on each mapper, then aggregates them on one reducer.
  // The last step is expensive if the dimensionality is great, since the reducer has to deserialize large StringKeyedVectors.

  /**
   * Trains a new model from scratch, running `strategy.getIters` iterations.
   *
   * @param instances     pipe containing the training instances
   * @param instanceField field of `instances` that holds each LabeledInstance
   * @param modelField    field under which the trained model is emitted. The default is 'model,
   *                      matching the other trainers; it used to be the typo 'mode.
   * @return a pipe whose `modelField` holds the trained, post-processed model
   */
  def train(instances: Pipe, instanceField: Symbol = 'instance, modelField: Symbol = 'model): Pipe = {
    // Begin training.
    trainRecursively(None, modelField, instances, instanceField, strategy.getIters)
  }

  /**
   * Continues training an existing model (read from `model`/`modelField`) for
   * `strategy.getIters` more iterations.
   */
  def reTrain(instances: Pipe, instanceField: Symbol, model: Pipe, modelField: Symbol): Pipe = {
    // Begin training.
    trainRecursively(Some(model), modelField, instances, instanceField, strategy.getIters)
  }

  // This implements a full iteration of training, ending with a pipe with a model.
  protected def trainIteration(modelPipe: Option[Pipe], modelField: Symbol, instancePipe: Pipe, instanceField: Symbol): Pipe = {
    val iterationField: Symbol = '__iteration__
    val modelCountField: Symbol = '__model_count__

    // Subsample instances.
    val subsampled = instancePipe.filter(instanceField) { i: LabeledInstance[L] =>
      math.random < strategy.sampleProb(i.getLabel)
    }

    // Get models on each mapper.
    (modelPipe match {
      // Later iterations: pair every instance with the previous iteration's model.
      case Some(pipe) => subsampled.project(instanceField).crossWithTiny(pipe.project(modelField))
      // First iteration: pair every instance with a fresh model from the strategy.
      case _ => subsampled.mapTo(instanceField -> (instanceField, modelField)) { x: LabeledInstance[L] => (x, strategy.getModel) }
    })
      // Count iteration numbers.
      .insert(iterationField, 0)
      .insert(modelCountField, 1)
      // Convert instances to instance list.
      .map(instanceField -> instanceField) { i: LabeledInstance[L] => List(i) }
      // Perform map-side aggregation of models, which are then sent to a single reduce node for merging.
      .groupAll {
        _.reduce[(M, List[LabeledInstance[L]], Int, Int)](
          (modelField, instanceField, iterationField, modelCountField) -> (modelField, instanceField, iterationField, modelCountField))(strategy.modelReduceFunction)
      }
      .mapTo((modelField, iterationField, modelCountField) -> modelField) { x: (M, Int, Int) =>
        strategy.endIteration(x._1, x._2, x._3)
      }
  }

  /**
   * Runs `iterations` training iterations, feeding each iteration's model into the next, and
   * applies `strategy.modelPostProcess` after the last one.
   */
  protected def trainRecursively(modelPipe: Option[Pipe], modelField: Symbol, instancePipe: Pipe, instanceField: Symbol, iterations: Int): Pipe = {
    val updatedPipe = trainIteration(modelPipe, modelField, instancePipe, instanceField)
    // `<= 1` rather than `== 1`: a non-positive iteration count previously recursed forever.
    if (iterations <= 1) {
      updatedPipe.map(modelField -> modelField) { strategy.modelPostProcess }
    } else {
      trainRecursively(Some(updatedPipe), modelField, instancePipe, instanceField, iterations - 1)
    }
  }
}

================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/util/BaseGridSearcher.scala
================================================
package com.etsy.conjecture.scalding.util

import scala.util.Random
import com.etsy.conjecture.scalding.util._
import com.etsy.conjecture.data._
import com.etsy.conjecture.model._
import com.twitter.scalding._
import cascading.tuple.Fields

/**
 * Interface for using conjecture's hyperparameter tuner
 * See DefaultGridSearcher job for an example of how to extend this class.
 */
abstract class BaseGridSearcher(args : Args) extends Job(args) {

  val input = args.getOrElse("input", "specify_an_input_dir")
  val out_dir = args.getOrElse("out_dir", "hypertuned")
  val folds = args.getOrElse("folds", "0").toInt
  val problemName = args.getOrElse("name", "demo_problem")
  val xmx = args.getOrElse("xmx", "3").toInt
  val containerMemory = (xmx * 1024 * 1.16).toInt

  // Let the user configure the field names on the command line.
val data_field_names = args.getOrElse("data_fields", "instance").split(",")
  val data_fields = data_field_names.tail.foldLeft(new Fields(data_field_names.head)) { (x,y) => x.append(new Fields(y)) }
  val instance_field = Symbol(args.getOrElse("instance_field", "instance"))
  val salt = args.getOrElse("salt", "")
  val numTrials = args.getOrElse("num_trials", "10").toInt

  val instances = SequenceFile(input, data_fields).project(instance_field)

  //Define your parameters to optimize
  val opts: DynamicOptions
  val parameters: Seq[HyperParameter[_]]

  //Define the searcher to use based on the classifier
  val searcher: HyperparameterSearcher[_,_,_]

  override def config: Map[AnyRef, AnyRef] = super.config ++ Map("mapred.child.java.opts" -> "-Xmx%dG".format(xmx),
    "mapreduce.map.memory.mb" -> containerMemory.toString,
    "mapreduce.reduce.memory.mb" -> containerMemory.toString
  )
}

================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/util/DynamicOptions.scala
================================================
package com.etsy.conjecture.scalding.util

import java.io.Serializable
import com.twitter.scalding.Args
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.reflect.runtime.universe._

/**
 * A mutable registry of named options, populated from scalding `Args` on construction and
 * convertible back to `Args` with `unParse`.
 */
class DynamicOptions(args: Args) extends Serializable {
  private val opts = new HashMap[String, DynamicOption[_]]

  /** The arguments that were unqualified by dashed options.
   * Currently unused but held for future work.
   * Must be declared before `parse(args)` below: `parse` appends to this buffer, and a val
   * declared later in the constructor would still be null at that point (NullPointerException
   * whenever positional arguments are present). */
  private val _remaining = new ArrayBuffer[String]
  def remaining: Seq[String] = _remaining

  parse(args)

  def get(key: String) = opts.get(key)
  def size = opts.size
  var strict = true
  def values = opts.values

  /** Registers an option; throws if an option with the same name already exists. */
  def +=(c: DynamicOption[_]): this.type = {
    if (opts.contains(c.name)) throw new Error("DynamicOption " + c.name + " already exists.")
    opts(c.name) = c
    this
  }

  /** Unregisters the option with the same name as `c`. */
  def -=(c: DynamicOption[_]): this.type = {
    opts -= c.name
    this
  }

  /** Parse sequence of command-line arguments.
   * Only the first value of each key is used. A key that is already registered is parsed with
   * that option's declared type; positional values (key "") go to `remaining`; any other key is
   * registered as a new String-valued option.
   */
  def parse(args: Args): Unit = {
    args.m.filter(!_._2.isEmpty).foreach { case (key, listValues) =>
      //Only take first value from the values
      val value = listValues.head
      key match {
        case k: String if (opts.contains(k)) => opts(k).parseValue(value)
        case "" => _remaining += value
        case _ => {
          opts.+=((key, new DynamicOption(key, value)))
          opts(key).parseValue(value)
        }
      }
    }
  }

  /** Converts every option that currently holds a value back into scalding `Args`.
   * Options whose value is null (e.g. built with the name-only constructor and never set) are
   * skipped; calling `toString` on them previously threw a NullPointerException.
   */
  def unParse: Args = {
    val newM = opts.collect { case (key, opt) if opt.value != null => key -> List(opt.value.toString()) }.toMap
    new Args(newM)
  }

  class DynamicOption[T](val name:String, val defaultValue:T)(implicit m: Manifest[T]) extends com.etsy.conjecture.scalding.util.DynamicOption[T] with Serializable {
    def this(name: String)(implicit m: Manifest[T]) = this(name, null.asInstanceOf[T])
    // Every option registers itself with the enclosing DynamicOptions when constructed.
    DynamicOptions.this += this
    private def valueClass: Class[_] = m.runtimeClass
    private def valueType = m.runtimeClass
    private def matches(str: String): Boolean = str == ("--" + name)
    var _value = defaultValue
    def value: T = { _value }
    def setValue(v: T) { _value = v }
    def hasValue = !(valueType eq classOf[Nothing])
    var setCount = 0

    /** Parses each of the supported value, increments set counter, and sets the value */
    def parseValue(args: String) = {
      setCount += 1
      setValue(valueType match {
        // NOTE(review): because of erasure, classOf[List[String]], classOf[List[Int]] and
        // classOf[List[Double]] are the same Class, so only the first List case can match. It
        // stores the raw String (not a List), so List-typed options are effectively unsupported.
        case t if t eq classOf[List[String]] => args.asInstanceOf[T]
        case t if t eq classOf[List[Int]] => args.map(_.toInt).asInstanceOf[T]
        case t if t eq classOf[List[Double]] => args.map(_.toDouble).asInstanceOf[T]
        case t if t eq classOf[Char] => args.head.asInstanceOf[T]
        case t if t eq classOf[String] => args.asInstanceOf[T]
        case t if t eq classOf[Short] => args.toShort.asInstanceOf[T]
        case t if t eq classOf[Int] => args.toInt.asInstanceOf[T]
        case t if t eq classOf[Long] => args.toLong.asInstanceOf[T]
        case t if t eq classOf[Double] => args.toDouble.asInstanceOf[T]
        case t if t eq classOf[Float] => args.toFloat.asInstanceOf[T]
        case t if t eq classOf[Boolean] => args.toBoolean.asInstanceOf[T]
        case otw => throw new Error("DynamicOption does not handle values of type " + otw)
      })
    }
  }

  override def toString: String = values.map(_.toString).mkString("\t")
}

/**
 * A named, typed option value. Two options are equal exactly when their names are equal,
 * consistent with `hashCode`.
 */
trait DynamicOption[T] {
  def name: String
  def defaultValue: T
  def value: T
  def setValue(v: T): Unit
  def hasValue: Boolean
  def setCount: Int
  def wasSet = setCount > 0

  override def toString: String = {
    if (hasValue) value match {
      case a: Seq[_] => Seq("--" + name + " ") ++ a.map(_.toString)
      case "" => Seq()
      case a: Any => Seq("--" + name + " " + value.toString)
    } else Seq()
  }.mkString("; ")

  override def hashCode = name.hashCode

  // Compare option names. The previous `name.equals(other)` compared this name String with the
  // other *object*, so two options with the same name were never equal and the relation was
  // asymmetric.
  override def equals(other: Any) = other match {
    case that: DynamicOption[_] => name == that.name
    case _ => false
  }
}

================================================
FILE: src/main/scala/com/etsy/conjecture/scalding/util/HyperparameterSearcher.scala
================================================
package com.etsy.conjecture.scalding.util

import java.io.Serializable
import cascading.pipe.Pipe
import com.etsy.conjecture.data._
import com.etsy.conjecture.evaluation._
import com.etsy.conjecture.model._
import com.etsy.conjecture.scalding.evaluate.{BinaryEvaluator, GenericEvaluator, MulticlassEvaluator, RegressionEvaluator}
import com.etsy.conjecture.scalding.train._
import com.twitter.scalding.Dsl._
import com.twitter.scalding._
import scala.util.Random

/**
 * Samples random parameter values to perform a fast efficient hyperparameter search
 */
abstract class HyperparameterSearcher[L <: Label, M <: UpdateableModel[L, M], E <: ModelEvaluation[L]](val options: DynamicOptions,
    val parameters: Seq[HyperParameter[_]], val numTrials: Int, rng: Random = new Random(0)) extends Serializable {

  //Map of trial id to a randomly sampled set of parameters
  val settings = (0 until numTrials).map { trial : Int => trial -> sampledParameters(rng) }.toMap

  def getModelTrainer(args: Args): ModelTrainerStrategy[L, M]
  val evaluator: GenericEvaluator[L]

  //draw parameter values from given sample method
  //save values in a new Arg instance
  def sampledParameters(rng: Random): Args = {
    parameters.foreach(_.set(rng))
    options.unParse
  }

  def search (instances: Pipe,
instance_field: Symbol): (Pipe, Pipe) = {
    //Split train test by ratio
    //TODO Should make this into a parameter
    // nextInt(10) <= 7 keeps values 0..7 of 0..9, i.e. roughly 80% of instances go to training.
    val splitSet = instances.map(instance_field -> '__fold) { li: LabeledInstance[L] => rng.nextInt(10) <= 7 }
    val trainSet = splitSet.filter('__fold) { foldId: Boolean => foldId }
    val testSet = splitSet.filter('__fold) { foldId: Boolean => !foldId }

    //Restructure data pipes into trainable format and assign an ID that defines the random setting to use
    val train = generateTrials(trainSet, instance_field)
    val test = generateTrials(testSet, instance_field)

    val models = trainTrials(train)
    val rawResults = evaluate(models, test)
    val runResults = rawResults
      .mapTo(('settings, 'eval) -> 'result) { x: (Args, Double) => x._2 + "\t" + x._1.toString }

    //Tally up test accs to find metrics on tested param values
    val paramReport = createParameterReport(rawResults)
    (runResults, paramReport)
  }

  /**
   * Replicates every instance once per trial id and groups the copies into one list per trial.
   *
   * @param instances      pipe of labeled instances
   * @param instance_field field of `instances` holding the LabeledInstance
   * @return a pipe with fields ('trial, 'instances)
   */
  def generateTrials(instances: Pipe, instance_field: Symbol): Pipe = {
    instances
      // Read from the caller-supplied field. It was previously hard-coded to 'instance, which broke
      // any other instance_field. The emitted field is still named 'instance.
      .flatMapTo(instance_field -> ('instance, 'trial)) { instance: LabeledInstance[L] =>
        settings.keySet.map { trial: Int => (instance, trial) }
      }
      .groupBy('trial) { _.toList[LabeledInstance[L]]('instance -> 'instances).reducers(1000) }
      .project('trial, 'instances)
  }

  /** Trains one model per trial, using that trial's sampled settings. Emits ('trial, 'model). */
  def trainTrials(instances: Pipe): Pipe = {
    instances
      .mapTo(('trial, 'instances) -> ('trial, 'model)) { x: (Int, List[LabeledInstance[L]]) =>
        //In the unlike case that a setting does not exist, use the default value
        val args: Args = settings.getOrElse(x._1, options.unParse)
        val instanceSet = x._2
        val modelTrainer = getModelTrainer(args)
        val model = modelTrainer.getModel
        //Train model
        instanceSet.foreach(model.update)
        (x._1, model)
      }
  }

  /** Joins each trial's model with that trial's test instances. Emits ('eval, 'settings). */
  def evaluate (models: Pipe, testSet: Pipe): Pipe = {
    val eval = models
      .joinWithSmaller('trial -> 'trial, testSet)
      .mapTo(('model, 'instances, 'trial) -> ('eval, 'settings)) { x: (M, List[LabeledInstance[L]], Int) =>
        val model = x._1
        val testList = x._2
        val args = settings.getOrElse(x._3, options.unParse)
        val acc = evaluateAccuracy(testList, model)
        (acc, args)
      }
    eval
  }

  /** Scores `model` on `instances` and returns the aggregator's "Acc (avg)" value. */
  def evaluateAccuracy(instances: List[LabeledInstance[L]], model: M): Double = {
    val eval = evaluator.build
    // foreach rather than map: the loop only feeds the evaluator; no result is collected.
    instances.foreach { instance: LabeledInstance[L] =>
      val realLabel = instance.getLabel
      val prediction = model.predict(instance.getVector)
      eval.add(realLabel, prediction)
    }
    val agg = new EvaluationAggregator[L]()
    agg.add(eval)
    agg.getValue("Acc (avg)")
  }

  //Tally up evaluation scores of random parameter values and create a report of the mean/stdDev/max/count_of_runs
  def createParameterReport(rawResults: Pipe): Pipe = {
    rawResults
      .groupAll{ _.toList[(Args, Double)](('settings, 'eval) -> 'results) }
      .mapTo('results -> 'report) { results: List[(Args,Double)] =>
        // Re-apply each trial's settings to `options`, then accumulate that trial's score into
        // every parameter's buckets.
        results.foreach { x =>
          options.parse(x._1)
          parameters.foreach(_.accumulate(x._2))
        }
        parameters.map(_.report).mkString("\n")
      }
  }
}

class BinaryHyperparameterSearcher(option: DynamicOptions, parameters: Seq[HyperParameter[_]], numTrials: Int)
  extends HyperparameterSearcher[BinaryLabel, UpdateableLinearModel[BinaryLabel], BinaryModelEvaluation](option, parameters, numTrials) {
  val evaluator = new BinaryEvaluator()
  def getModelTrainer(args: Args) = new BinaryModelTrainer(args)
}

class RegressionHyperparameterSearcher(option: DynamicOptions, parameters: Seq[HyperParameter[_]], numTrials: Int)
  extends HyperparameterSearcher[RealValuedLabel, UpdateableLinearModel[RealValuedLabel], RegressionModelEvaluation](option, parameters, numTrials) {
  val evaluator = new RegressionEvaluator()
  def getModelTrainer(args: Args) = new RegressionModelTrainer(args)
}

class MulticlassHyperparameterSearcher(option: DynamicOptions, parameters: Seq[HyperParameter[_]], numTrials: Int, categories: Array[String])
  extends HyperparameterSearcher[MulticlassLabel, UpdateableMulticlassLinearModel, MulticlassModelEvaluation](option, parameters, numTrials) {
  val evaluator = new MulticlassEvaluator(categories)
  def getModelTrainer(args: Args) = new MulticlassModelTrainer(args, categories)
}

/**
 * Sampling method for hyperparameters
 * Also defines how to bucket parameter values and accuracies for hyperparameter reports
 */
trait ParameterSampler[T] extends Serializable {
  def sample(rng: scala.util.Random): T
  //Array corresponds to bucketed parameter value, then the sum, sumSq, max, count of times that param bucket was run
  val buckets: Array[(T, Double, Double, Double, Int)]
  def valueToBucket(v: T): Int

  /** Adds objective `d` for parameter `value` to its bucket (indices clamped into range). */
  def accumulate(value: T, d: Double) {
    val v = math.max(0, math.min(valueToBucket(value), buckets.length-1))
    val (_, sum, sumSq, max, count) = this.buckets(v)
    // Buckets are initialised with max = 0.0, so an empty bucket must take the first objective as
    // its max. Otherwise a bucket whose objectives are all negative would report 0.0.
    val newMax = if (count == 0) d else math.max(max, d)
    this.buckets(v) = (value, sum+d, sumSq+d*d, newMax, count+1)
  }
}

/**
 * Samples uniformly one value from the sequence.
 */
class SampleFromSeq[T](seq: Seq[T]) extends ParameterSampler[T] {
  val buckets = seq.map(s => (s, 0.0, 0.0, 0.0, 0)).toArray
  def valueToBucket(v: T) = buckets.toSeq.map(_._1).indexOf(v)
  def sample(rng: Random) = seq(rng.nextInt(seq.length))
}

/**
 * Samples uniformly a double that falls within the range
 */
class UniformDoubleSampler(lower: Double, upper: Double, numBuckets: Int = 10) extends ParameterSampler[Double] {
  val dif = upper - lower
  val buckets = (0 to numBuckets).map(i => (0.0, 0.0, 0.0, 0.0, 0)).toArray
  def valueToBucket(d: Double) = (numBuckets*(d - lower)/dif).toInt
  def sample(rng: Random) = rng.nextDouble()*dif + lower
}

/**
 * Samples Doubles in the range such that their logarithm is uniform.
 * Useful for learning rates, variances, alphas, and other things which
 * vary in order of magnitude.
 */
class LogUniformDoubleSampler(lower: Double, upper: Double, numBuckets: Int = 10) extends ParameterSampler[Double] {
  val inner = new UniformDoubleSampler(math.log(lower), math.log(upper), numBuckets)
  def valueToBucket(v: Double) = inner.valueToBucket(math.log(v))
  val buckets = (0 to numBuckets).map(i => (0.0, 0.0, 0.0, 0.0, 0)).toArray
  def sample(rng: Random) = math.exp(inner.sample(rng))
}

/**
 * A container for a hyperparameter
 * @param option The DynamicOption wrapper for the parameter
 * @param sampler Sampler to use to return values for the parameter
 */
class HyperParameter[T](option: DynamicOption[T], val sampler: ParameterSampler[T]) {
  val buckets = sampler.buckets

  /** Stores a freshly sampled value in the wrapped option. */
  def set(rng: Random) {
    option.setValue(sampler.sample(rng))
  }

  /** Records `objective` against the option's current value. */
  def accumulate(objective: Double) {
    sampler.accumulate(option.value, objective)
  }

  /** One header line, then one tab-separated line (value, mean, stdDev, max, count) per sampled bucket. */
  def report(): String = {
    val buff = new StringBuilder("Parameter: "+option.name+"\tMean\tStdDev\tMax\tCount\n")
    // Skip buckets that were never sampled: their mean and stdDev would be 0/0 = NaN.
    for ((value, sum, sumSq, max, count) <- buckets if count > 0) {
      val mean = sum/count
      // Clamp at 0: floating-point cancellation can make E[x^2] - E[x]^2 slightly negative.
      val stdDev = math.sqrt(math.max(0.0, sumSq/count - mean*mean))
      val metrics = value match {
        case v: Double => Vector(f"${v.toDouble}%2.15f", f"$mean%2.4f", f"$stdDev%2.4f",f"$max%1.4f", count).mkString("\t")
        case _ => Vector(f"${value.toString}%20s", f"$mean%2.4f", f"$stdDev%2.4f",f"$max%1.4f", count).mkString("\t")
      }
      buff.append(metrics + "\n")
    }
    buff.toString()
  }
}

================================================
FILE: src/main/scala/com/etsy/conjecture/text/FeatureHelper.scala
================================================
package com.etsy.conjecture.text

import com.etsy.conjecture.data.{ AbstractInstance, BinaryLabeledInstance, LabeledInstance, StringKeyedVector }
import com.twitter.algebird.Operators._
import cascading.tuple.Fields
import cascading.pipe.Pipe
import scala.collection.JavaConverters._
import spray.json._
import DefaultJsonProtocol._

object FeatureHelper {
  import com.twitter.scalding.Dsl._

  def keepFeaturesWithCountGreaterThan(pipe: Pipe, instance_field: Fields, n: Int):
Pipe = {
    // Count, per feature key, how many vectors contain it, and keep the keys seen more than n times.
    val counts = pipe
      .flatMapTo(instance_field -> ('term, '__count)) { v: AnyRef =>
        val vector = v match {
          case skv: StringKeyedVector => skv
          case ins: AbstractInstance[_] => ins.getVector
          case lin: LabeledInstance[_] => lin.getVector
          case _ => throw new IllegalArgumentException("keepFeaturesWithCountGreaterThan does not expect class: " + v.getClass.getName)
        }
        vector.keySet.asScala.map { k => k -> 1 }
      }
      .groupBy('term) { _.sum[Long]('__count) }
      .filter('__count) { c: Long => c > n }
      .mapTo('term -> 'set) { t: String => Set(t) }
      .groupAll { _.sum[Set[String]]('set) }

    // Attach the surviving key set to every row and remove all other keys from each vector in place.
    pipe
      .crossWithTiny(counts)
      .map(instance_field.append('set) -> instance_field) { x: (AnyRef, Set[String]) =>
        val skv = x._1 match {
          case s: StringKeyedVector => s
          case i: AbstractInstance[_] => i.getVector
          case l: LabeledInstance[_] => l.getVector
          case _ => throw new IllegalArgumentException("keepFeaturesWithCountGreaterThan does not expect class: " + x._1.getClass.getName)
        }
        val it = skv.iterator
        while (it.hasNext) {
          val e = it.next
          if (!x._2.contains(e.getKey)) {
            it.remove
          }
        }
        x._1
      }
  }

  /**
   * Lower-cases `string`, normalizes it with `Text.standardTextFilter`, and returns every n-gram of
   * length 1..n joined with "::" and prefixed with `prefix`. The token list is padded with an empty
   * string at both ends, so boundary grams such as "::first" are included.
   */
  def nGramsUpTo(string: String, n: Int = 2, prefix: String = ""): List[String] = {
    val toks = Text(string.toLowerCase).standardTextFilter.toString.split(" ").toList
    val toks_pad = "" +: toks :+ ""
    val grams = (1 to n).map { m => toks_pad.sliding(m).toList.map { p => p.mkString("::") } }.foldLeft(List[String]()) { _ ++ _ }
    grams.filter { g => g != "" }.map { g => prefix + g }
  }

  /** Builds a StringKeyedVector in which every string of `list` has coordinate `weight`. */
  def stringListToSKV(list: List[String], weight: Double = 1.0): StringKeyedVector = {
    val skv = new StringKeyedVector();
    list.foreach { f => skv.setCoordinate(f, weight) }
    skv
  }

  /** The body text chosen by `parseEmailBodyToTextAndType`, or None when the email is filtered. */
  def getEmailBody(body: String): Option[String] = {
    val p = parseEmailBodyToTextAndType(body)
    if (p._1 != null) Some(p._1) else None
  }

  /**
   * Parses `body` as a JSON list of parts keyed by "type" and "body". Returns the space-joined
   * text/plain bodies with "text/plain" if any exist, else the text/html bodies with "text/html".
   * Otherwise, or on any parse error, returns (null, "filter").
   */
  def parseEmailBodyToTextAndType(body: String): (String, String) = {
    try {
      val email = JsonParser(body).convertTo[List[Map[String, String]]]
      val textParts = email.filter(part => part("type") == "text/plain")
      if (textParts.length > 0)
        (textParts.map(part => part("body")).mkString(" "), "text/plain")
      else {
        val htmlParts = email.filter(part => part("type") == "text/html")
        if (htmlParts.length > 0)
          (htmlParts.map(part => part("body")).mkString(" "), "text/html")
        else
          (null, "filter") // Filter this email
      }
    } catch {
      case _ : Exception => (null, "filter")
    }
  }
}

================================================
FILE: src/main/scala/com/etsy/conjecture/text/Text.scala
================================================
package com.etsy.conjecture.text

/**
 * Immutable wrapper around a String with chainable text-normalization steps. Every method
 * returns a new Text and leaves `input` unchanged.
 */
case class Text(val input: String) {
  private implicit def text2str(txt: Text): String = txt.input
  private implicit def str2text(str: String): Text = new Text(str)

  override def toString = input.toString

  // Replace each run of digits with `replacement`, then merge pairs of replacements separated only by
  // whitespace. NOTE(review): `replacement` is also used as a regex in the second step.
  def replaceNumbers(replacement: String = "_num_") = Text(input.replaceAll("[0-9]+", replacement).replaceAll(replacement + "\\s+" + replacement, replacement))
  def replaceHTMLEscapes(replacement: String = " ") = Text(input.replaceAll("&[^;]+;", replacement))
  def removeHTMLTags() = Text(input.replaceAll("<.*?>", " ")) //Text(XML.loadString(input).text)
  // Honors `replacement`; it was previously ignored in favor of a hard-coded " " (still the default).
  def replaceHTMLTags(replacement: String = " ") = Text(input.replaceAll("<[^>]+>", replacement))
  def replaceNonAlphaNumeric(replacement: String = " ") = Text(input.replaceAll("[^a-zA-Z0-9\\.\\s\\-]+", replacement))
  def replaceNonAlphaNumericUnderscore(replacement: String = " ") = Text(input.replaceAll("[^a-zA-Z0-9\\.\\s\\-_]+", replacement))
  def replaceNonAlpha(replacement: String = " ") = Text(input.replaceAll("[^a-zA-Z]+", replacement))
  def collapseHyphens() = Text(input.replaceAll("--+", "--"))
  def collapseUnderscores() = Text(input.replaceAll("__+", "__"))
  def collapsePeriods() = Text(input.replaceAll("\\.\\.+", ".."))
  def toLowerCase() = Text(input.toLowerCase)
  def toUpperCase() = Text(input.toUpperCase)
  // Strip leading and trailing non-alphanumeric characters. The leading class previously read
  // [^A-Za-z0-0] (a typo), which also stripped leading digits 1-9.
  def stripPunctuation() = Text(input.replaceAll("^[^A-Za-z0-9]+", "").replaceAll("[^A-Za-z0-9]+$", ""))
  // compact any white space
  def collapse() = Text(input.replaceAll("\\s+", " "))
  // remove any whitespace from the right of a string
  def rstrip() = Text(input.replaceAll("\\s+$", ""))
  // remove any whitespace from the left of a string
  def lstrip() = Text(input.replaceAll("^\\s+", ""))
  // remove any leading or trailing whitespace
  def strip() = Text(input.trim)
  // clean up any whitespace
  def wsclean() = strip().collapse()
  // remove any unprintable non-ASCII characters
  // NOTE(review): the parameter shadows the field, so this cleans its argument, not this Text.
  def removeUnprintables(input: String) = Text(input.replaceAll("[^\\x20-\\x7E]", ""))
  def collapseWhitespaceAndPunc = Text(input.replaceAll("\\s+", " ")
    .replaceAll("[\\-]+", "-")
    .replaceAll("[\\.]+", "."))

  // Default normalization pipeline: drop tags and entities, replace numbers, replace other symbols,
  // collapse repeats, and clean whitespace.
  def standardTextFilter = Text(removeHTMLTags()
    .replaceHTMLEscapes()
    .replaceNumbers()
    .replaceNonAlphaNumericUnderscore()
    .collapseHyphens()
    .collapseUnderscores()
    .wsclean())

  // All character substrings of length n (and of each extra length in ns).
  def toListFromShingles(n: Int, ns: Int*): List[String] = (List(n) ++ ns.toList).flatMap{ i: Int => input.sliding(i) }.toList
  def toSequenceFromShingles(n: Int, ns: Int*): TextSequence = new TextSequence(toListFromShingles(n, ns: _*))
  // `sep` is interpreted as a regex by String.split.
  def toList(sep: String = " "): List[String] = input.split(sep).toList
  def toSequence(sep: String = " "): TextSequence = new TextSequence(toList(sep))
  def isEmpty(): Boolean = input.isEmpty()
}

================================================
FILE: src/main/scala/com/etsy/conjecture/text/TextSequence.scala
================================================
package com.etsy.conjecture.text

import com.etsy.conjecture.data.{BinaryLabeledInstance,BinaryLabel,MulticlassLabel,MulticlassLabeledInstance}

case class TextSequence(tokens: Seq[String]) {
  def ++(that: TextSequence): TextSequence = TextSequence(tokens ++ that.tokens)
  def mkString(glue: String = " "): String = tokens.mkString(glue)
  override def toString = mkString(" ")
  def intersect(that: TextSequence): TextSequence = TextSequence(tokens.intersect(that.tokens))
  // Removes empty tokens, like filterStopwords/filterBadwords remove their matches.
  // It previously kept *only* the empty tokens.
  def filterBlank = TextSequence(tokens.filter { x => !x.isEmpty })
  def filterStopwords = TextSequence(tokens.filter { x => !Stopwords(x.toLowerCase) })
  def stopwords = TextSequence(tokens.filter { x =>
Stopwords(x.toLowerCase) })
  def filterBadwords = TextSequence(tokens.filter { x => !BadWords(x.toLowerCase) })
  def badwords = TextSequence(tokens.filter { x => BadWords(x.toLowerCase) })
  def filterAllCaps = TextSequence(tokens.filter { x => !x.matches("^[A-Z]+$") })
  def allCaps = TextSequence(tokens.filter { x => x.matches("^[A-Z]+$") })
  def filterCapitalized = TextSequence(tokens.filter { x => !x.matches("^[A-Z][^A-Z]*") })
  def capitalized = TextSequence(tokens.filter { x => x.matches("^[A-Z][^A-Z]*") })
  def filterLowercase = TextSequence(tokens.filter { x => !x.matches("^[a-z]+$") })
  def allLowercase = TextSequence(tokens.filter { x => x.matches("^[a-z]+$") })
  def filterURLs = TextSequence(tokens.filter { x => !x.matches("^https?://.+") })
  def allURLs = TextSequence(tokens.filter { x => x.matches("^https?://.+") })
  def filterListings = TextSequence(tokens.filter { x => !x.matches("^https?://.+etsy.+/listing/[0-9]+.*") })
  def allListings = TextSequence(tokens.filter { x => x.matches("^https?://.+etsy.+/listing/[0-9]+.*") })

  def size: Int = tokens.size

  // Shared implementation of the *Fraq methods. It returns 0 for an empty sequence; the original
  // inline expression threw ArithmeticException (integer division by zero) in that case.
  // NOTE(review): `bins * count / size` is integer division, and the result is then divided by
  // `bins` and truncated. For 0 <= count <= size and bins > 0 this yields 1 only when
  // count == size, and 0 otherwise. Confirm whether a bin index (0..bins) was intended.
  private def fraq(count: Int, bins: Int): Int =
    if (size == 0) 0 else (math.round(bins * count / size) / bins.toDouble).toInt

  def stopWordCount: Int = stopwords.size
  def stopWordFraq(bins: Int = 10): Int = fraq(stopWordCount, bins)
  def badWordCount: Int = badwords.size
  def badWordFraq(bins: Int = 10): Int = fraq(badWordCount, bins)
  def capsCount: Int = allCaps.size
  def capFraq(bins: Int = 10): Int = fraq(capsCount, bins)
  def urlCount: Int = allURLs.size
  def urlFraq(bins: Int = 10): Int = fraq(urlCount, bins)
  def listingsCount: Int = allListings.size
  def listingsFraq(bins: Int = 10): Int = fraq(listingsCount, bins)
  // NOTE(review): for an empty sequence math.log(0) is -Infinity, which .toInt maps to Int.MinValue.
  def sizeBin = math.floor(math.log(size)).toInt

  // filtering methods
  def replaceNumbers(replacement: String = "_num_") = TextSequence(tokens.map { input => input.replaceAll("[0-9]+", replacement).replaceAll(replacement + "\\s+" + replacement, replacement) })
  def replaceHTMLEscapes(replacement: String = " ") = TextSequence(tokens.map { input => input.replaceAll("&[^;]+;", replacement) })
  def removeHTMLTags() = TextSequence(tokens.map { input => input.replaceAll("<.*?>", " ") })
  // Honors `replacement`; it was previously ignored in favor of a hard-coded " " (still the default).
  def replaceHTMLTags(replacement: String = " ") = TextSequence(tokens.map { input => input.replaceAll("<[^>]+>", replacement) })
  def replaceNonAlphaNumeric(replacement: String = " ") = TextSequence(tokens.map { input => input.replaceAll("[^a-zA-Z0-9\\.\\s\\-]+", replacement) })
  def replaceNonAlphaNumericUnderscore(replacement: String = " ") = TextSequence(tokens.map { input => input.replaceAll("[^a-zA-Z0-9\\.\\s\\-_]+", replacement) })
  def replaceNonAlpha(replacement: String = " ") = TextSequence(tokens.map { input => input.replaceAll("[^a-zA-Z]+", replacement) })
  def collapseHyphens() = TextSequence(tokens.map { input => input.replaceAll("--+", "--") })
  def collapseUnderscores() = TextSequence(tokens.map { input => input.replaceAll("__+", "__") })
  def collapsePeriods() = TextSequence(tokens.map { input => input.replaceAll("\\.\\.+", "..") })
  // Strip leading and trailing non-alphanumeric characters of each token. The leading class
  // previously read [^A-Za-z0-0] (a typo), which also stripped leading digits 1-9.
  def stripPunctuation() = TextSequence(tokens.map { input => input.replaceAll("^[^A-Za-z0-9]+", "").replaceAll("[^A-Za-z0-9]+$", "") })
  // compact any white space
  def collapse() = TextSequence(tokens.map { input => input.replaceAll("\\s+", " ") })
  // remove any whitespace from the right of a string
  def rstrip() = TextSequence(tokens.map { input => input.replaceAll("\\s+$", "") })
  // remove any whitespace from the left of a string
  def lstrip() = TextSequence(tokens.map { input => input.replaceAll("^\\s+", "") })
  // remove any leading or trailing whitespace
  def strip() = TextSequence(tokens.map { input => (input.trim) })
  // clean up any whitespace
  def wsclean() = strip().collapse()
  // remove any unprintable non-ASCII characters
  // NOTE(review): the `input` parameter is unused; the lambda parameter shadows it and every token is cleaned.
  def removeUnprintables(input: String) = TextSequence(tokens.map { input => input.replaceAll("[^\\x20-\\x7E]", "") })
  def collapseWhitespaceAndPunc = TextSequence(tokens.map { input =>
    input.replaceAll("\\s+", " ")
      .replaceAll("[\\-]+", "-")
      .replaceAll("[\\.]+", ".")
  })

  // Word n-grams: consecutive windows of n tokens joined with `glue`.
  def ngrams(n: Int, glue: String = " ") = new TextSequence(tokens.sliding(n).map { x => x.mkString(glue) }.toList)

  // Character shingles of length n over the tokens joined with `whitespace`.
  def shingles(n: Int, whitespace: String = "_"): TextSequence = {
    val str = tokens.mkString(whitespace)
    TextSequence(str.sliding(n).toList)
  }

  def prependNameSpace(namespace: String) = new TextSequence(tokens.map { x => namespace + x })

  def toList = tokens.toList

  def toBinaryLabeledInstance(label: Double): BinaryLabeledInstance = {
    toBinaryLabeledInstance(new BinaryLabel(label))
  }

  def toBinaryLabeledInstance(label: BinaryLabel): BinaryLabeledInstance = {
    val instance = new BinaryLabeledInstance(label)
    tokens.foreach { x => instance.addTerm(x) }
    instance
  }

  def toMulticlassLabeledInstance(label: MulticlassLabel): MulticlassLabeledInstance = {
    val instance = new MulticlassLabeledInstance(label)
    tokens.foreach { x => instance.addTerm(x) }
    instance
  }
}

object Stopwords {
  def apply(input: String): Boolean = stopwords.contains(input)

  val stopwords = Set("a", "as", "able", "about", "above", "according", "accordingly", "across", "actually", "after",
    "afterwards", "again", "against", "aint", "all", "allow", "allows", "almost", "alone", "along",
    "already", "also", "although", "always", "am", "among", "amongst", "amoungst", "amount", "an",
    "and", "another", "any", "anybody", "anyhow", "anyone", "anything", "anyway", "anyways", "anywhere",
    "apart", "appear", "appreciate", "appropriate", "are", "arent", "around", "as", "aside", "ask",
    "asking", "associated", "at", "available", "away", "awfully", "b", "back", "be", "became",
    "because", "become", "becomes", "becoming", "been", "before", "beforehand", "behind", "being", "believe",
    "below", "beside", "besides", "best", "better", "between", "beyond", "bill", "both", "bottom",
    "brief", "but", "by", "c", "cmon", "cs", "call", "came", "can", "cant",
    "cannot", "cant", "cause", "causes", "certain", "certainly", "changes", "clearly",
"co", "com", "come", "comes", "con", "concerning", "consequently", "consider", "considering", "contain", "containing", "contains", "corresponding", "could", "couldnt", "couldnt", "course", "cry", "currently", "d", "de", "definitely", "describe", "described", "despite", "detail", "did", "didnt", "different", "do", "does", "doesnt", "doing", "dont", "done", "down", "downwards", "due", "during", "e", "each", "edu", "eg", "eight", "either", "eleven", "else", "elsewhere", "empty", "enough", "entirely", "especially", "et", "etc", "even", "ever", "every", "everybody", "everyone", "everything", "everywhere", "ex", "exactly", "example", "except", "f", "far", "few", "fifteen", "fifth", "fify", "fill", "find", "fire", "first", "five", "followed", "following", "follows", "for", "former", "formerly", "forth", "forty", "found", "four", "from", "front", "full", "further", "furthermore", "g", "get", "gets", "getting", "give", "given", "gives", "go", "goes", "going", "gone", "got", "gotten", "greetings", "h", "had", "hadnt", "happens", "hardly", "has", "hasnt", "hasnt", "have", "havent", "having", "he", "hes", "hello", "help", "hence", "her", "here", "heres", "hereafter", "hereby", "herein", "hereupon", "hers", "herself", "hi", "him", "himself", "his", "hither", "hopefully", "how", "howbeit", "however", "hundred", "i", "id", "ill", "im", "ive", "ie", "if", "ignored", "immediate", "in", "inasmuch", "inc", "indeed", "indicate", "indicated", "indicates", "inner", "insofar", "instead", "interest", "into", "inward", "is", "isnt", "it", "itd", "itll", "its", "its", "itself", "j", "just", "k", "keep", "keeps", "kept", "know", "known", "knows", "l", "last", "lately", "later", "latter", "latterly", "least", "less", "lest", "let", "lets", "like", "liked", "likely", "little", "look", "looking", "looks", "ltd", "m", "made", "mainly", "many", "may", "maybe", "me", "mean", "meanwhile", "merely", "might", "mill", "mine", "more", "moreover", "most", "mostly", "move", "much", "must", "my", 
"myself", "n", "name", "namely", "nd", "near", "nearly", "necessary", "need", "needs", "neither", "never", "nevertheless", "new", "next", "nine", "no", "nobody", "non", "none", "noone", "nor", "normally", "not", "nothing", "novel", "now", "nowhere", "o", "obviously", "of", "off", "often", "oh", "ok", "okay", "old", "on", "once", "one", "ones", "only", "onto", "or", "other", "others", "otherwise", "ought", "our", "ours", "ourselves", "out", "outside", "over", "overall", "own", "p", "part", "particular", "particularly", "per", "perhaps", "placed", "please", "plus", "possible", "presumably", "probably", "provides", "put", "q", "que", "quite", "qv", "r", "rather", "rd", "re", "really", "reasonably", "regarding", "regardless", "regards", "relatively", "respectively", "right", "s", "said", "same", "saw", "say", "saying", "says", "second", "secondly", "see", "seeing", "seem", "seemed", "seeming", "seems", "seen", "self", "selves", "sensible", "sent", "serious", "seriously", "seven", "several", "shall", "she", "should", "shouldnt", "show", "side", "since", "sincere", "six", "sixty", "so", "some", "somebody", "somehow", "someone", "something", "sometime", "sometimes", "somewhat", "somewhere", "soon", "sorry", "specified", "specify", "specifying", "still", "sub", "such", "sup", "sure", "system", "t", "ts", "take", "taken", "tell", "ten", "tends", "th", "than", "thank", "thanks", "thanx", "that", "thats", "thats", "the", "thea", "their", "theirs", "them", "themselves", "then", "thence", "there", "theres", "thereafter", "thereby", "therefore", "therein", "theres", "thereupon", "these", "they", "theyd", "theyll", "theyre", "theyve", "thickv", "thin", "think", "third", "this", "thorough", "thoroughly", "those", "though", "three", "through", "throughout", "thru", "thus", "to", "together", "too", "took", "top", "toward", "towards", "tried", "tries", "truly", "try", "trying", "twelve", "twenty", "twice", "two", "u", "un", "under", "unfortunately", "unless", "unlikely", "until", 
"unto", "up", "upon", "us", "use", "used", "useful", "uses", "using", "usually", "uucp", "v", "value", "various", "very", "via", "viz", "vs", "w", "want", "wants", "was", "wasnt", "way", "we", "wed", "well", "were", "weve", "welcome", "well", "went", "were", "werent", "what", "whats", "whatever", "when", "whence", "whenever", "where", "wheres", "whereafter", "whereas", "whereby", "wherein", "whereupon", "wherever", "whether", "which", "while", "whither", "who", "whos", "whoever", "whole", "whom", "whose", "why", "will", "willing", "wish", "with", "within", "without", "wont", "wonder", "would", "wouldnt", "x", "y", "yes", "yet", "you", "youd", "youll", "youre", "youve", "your", "yours", "yourself", "yourselves", "z", "zero")
}

/**
 * Exact-match lookup against a fixed list of offensive words.
 *
 * Matching is plain Set membership: whole-string and case-sensitive, with no normalization.
 */
object BadWords {
  /** The fixed word list (the Set collapses any repeated literal). */
  val badwords = Set(
    "ahole", "arse", "ass", "asshole", "asswipe", "bastard", "batty", "bender", "bitch",
    "bloody", "bollocks", "boner", "bumboy", "bugger", "coon", "cock", "cocksucker", "cracker",
    "crap", "cumsucker", "cunt", "damn", "dick", "dildo", "douchebag", "faggot", "fistfucker",
    "fuck", "fucker", "fuckwit", "fucktwat", "gaylord", "ho", "honky", "jackass", "jism", "joey",
    "knobcheese", "minge", "minger", "mong", "motherfucker", "munter", "pickle", "piss", "piss",
    "prick", "pussy", "rimmer", "schmuck", "shit", "slut", "spakka", "spaz", "skank", "taint",
    "tit", "tool", "tosser", "twat", "whore", "wanker")

  /** Returns true iff `input` is exactly one of the entries of `badwords`. */
  def apply(input: String): Boolean = badwords(input)
}
================================================
FILE: src/main/scala/com/etsy/scalding/jobs/conjecture/AdHocClassifier.scala
================================================
package com.etsy.scalding.jobs.conjecture

import com.twitter.scalding.{Args, Job, Mode, SequenceFile, Tsv}
import com.etsy.conjecture.scalding.evaluate.BinaryCrossValidator
import com.etsy.conjecture.scalding.train.BinaryModelTrainer
import com.etsy.conjecture.data.{BinaryLabel,BinaryLabeledInstance,StringKeyedVector}
import com.etsy.conjecture.model.UpdateableLinearModel
import
com.google.gson.Gson
import cascading.tuple.Fields

/**
 * Ad hoc binary-classifier job.
 *
 * Reads labeled instances from a sequence file, trains a model with BinaryModelTrainer, and
 * writes it to `<out_dir>/model` (sequence file) and `<out_dir>/model_json` (Gson JSON).
 * When `--folds` is positive it also runs cross-validation via BinaryCrossValidator, writing
 * the first result pipe to `<out_dir>/xval` (TSV) and the second to `<out_dir>/pred`.
 *
 * Other arguments: `--input`, `--name`, `--xmx` (heap in GB), `--data_fields`
 * (comma-separated field names of the input file) and `--instance_field`. The trainer and the
 * cross-validator also receive the full `args`.
 */
class AdHocClassifier(args : Args) extends Job(args) {
  val input = args.getOrElse("input", "specify_an_input_dir")
  val out_dir = args.getOrElse("out_dir", "adhoc_classifier")
  val folds = args.getOrElse("folds", "0").toInt
  // NOTE(review): read but not referenced anywhere in this job.
  val problemName = args.getOrElse("name", "demo_problem")
  val xmx = args.getOrElse("xmx", "3").toInt
  // Container request: the JVM heap (xmx GB) in MB plus ~16% headroom.
  val containerMemory = (xmx * 1024 * 1.16).toInt

  // Field layout of the input sequence file, configurable on the command line.
  val data_field_names = args.getOrElse("data_fields", "instance").split(",")
  val data_fields = data_field_names.tail.foldLeft(new Fields(data_field_names.head)) {
    (fields, name) => fields.append(new Fields(name))
  }
  val instance_field = Symbol(args.getOrElse("instance_field", "instance"))

  // Only the instance column of the input sequence file is used.
  val instances = SequenceFile(input, data_fields).project(instance_field)

  val model_pipe = new BinaryModelTrainer(args).train(instances, instance_field, 'model)

  model_pipe
    .write(SequenceFile(out_dir + "/model"))
    .mapTo('model -> 'json) { model : UpdateableLinearModel[BinaryLabel] => new Gson().toJson(model) }
    .write(Tsv(out_dir + "/model_json"))

  if (folds > 0) {
    val (xval, pred) = new BinaryCrossValidator(args, folds)
      .crossValidateWithPredictions(instances, instance_field, 'pred)
    xval.write(Tsv(out_dir + "/xval"))
    pred.write(SequenceFile(out_dir + "/pred"))
  }

  override def config = {
    val containerMb = containerMemory.toString
    super.config ++ Map(
      "mapred.child.java.opts" -> "-Xmx%dG".format(xmx),
      "mapreduce.map.memory.mb" -> containerMb,
      "mapreduce.reduce.memory.mb" -> containerMb)
  }
}
================================================
FILE: src/main/scala/com/etsy/scalding/jobs/conjecture/AdHocClusterer.scala
================================================
package com.etsy.scalding.jobs.conjecture

import com.twitter.scalding.{Args, Job, Mode, SequenceFile, Tsv}
import com.etsy.conjecture.data.StringKeyedVector
import cascading.pipe.Pipe
import com.twitter.scalding._
import
com.etsy.conjecture.data._
import com.etsy.conjecture.model._
import scala.collection.JavaConversions._
import java.io.File
import scala.io.Source

/**
 * Implements kmeans|| as described here: http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf
 * Also includes fast L1 projection step to find sparse cluster centers as described here:
 * http://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf
 *
 * The job re-launches itself (see `next`) once per iteration. Between runs it persists the
 * current centers under `<out_dir>/iter_<n>/centers`. Iterations fall into three phases:
 *   - [0, init_iters): kmeans|| oversampling over the input points;
 *   - [init_iters, init_iters + finish_iters]: recluster the oversampled centers into
 *     num_clusters centers. The first of these runs is seeded with a greedy
 *     farthest-point (kmeans++-style) pass;
 *   - (init_iters + finish_iters, total]: Lloyd iterations over the full dataset.
 *
 * Usage:
 * --curr_iter            : Set the current iteration.
 * --num_starting_centers : Number of starting points to select at random to initialize C.
 * --init_iters           : The number of initial iterations to do to find C oversampled centers.
 * --finish_iters         : The number of iterations to cluster the C oversampled centers into
 *                          K starting centers.
 * --oversampling_factor  : The number of points to oversample on each iteration of the
 *                          parallel kmeans initialization, described as a fraction
 *                          of the number of centers.
 * --kmeans_iters         : The number of iterations to cluster the original dataset.
 * --num_clusters         : Number of final clusters (K).
 * --input                : Path on hdfs to the dataset to be clustered. Dataset should be a pipe
 *                          of (id_field : String, instance_field : StringKeyedVector).
 * --out_dir              : Path where intermediate data, final cluster centers, and assignments
 *                          will be written.
 * --id_field             : Symbol for the id of the point being clustered, (e.g. doc_id).
 * --instance_field       : Symbol for the point being clustered, (e.g. document).
 * --sparsify             : Whether or not to enforce cluster center sparsity.
 * --ball_radius          : Radius of ball to project cluster centers on to in l1 projection.
 *                          E.g. 10^-1 == more sparse, 10^2 == less sparse.
 * --error_tolerance      : Relative error tolerance in the e-accurate l1 projection.
 * --generate_assignments : On the final iteration, also write (id_field, cluster_assignment).
 * --xmx                  : JVM heap in GB; container memory is derived from it.
 */
class AdHocClustererTest(args: Args) extends Job(args) {
  val curr_iter = args.getOrElse("curr_iter","0").toInt
  val num_starting_centers = args.getOrElse("num_starting_centers","10").toInt
  val init_iters = args.getOrElse("init_iters","5").toInt
  val finish_iters = args.getOrElse("finish_iters","5").toInt
  val oversampling_factor = args.getOrElse("oversampling_factor","1.0").toDouble
  val kmeans_iters = args.getOrElse("kmeans_iters","5").toInt
  val input = args.getOrElse("input", "specify_an_input_dir")
  val out_dir = args.getOrElse("out_dir", "specify_an_output_dir")+"/"
  val id_field = Symbol(args.getOrElse("id_field", "id"))
  val instance_field = Symbol(args.getOrElse("instance_field", "instance"))
  val xmx = args.getOrElse("xmx", "3").toInt
  val containerMemory = (xmx * 1024 * 1.16).toInt
  val max_finish_iters = init_iters + finish_iters
  val total_iter = init_iters + finish_iters + kmeans_iters

  /* Number of clusters to build */
  val num_clusters = args.getOrElse("num_clusters","100").toInt

  /* Number of centers to add per kmeans|| oversampling round */
  val take_per_round = math.floor(num_clusters * oversampling_factor).toInt

  /* Whether or not to enforce cluster center sparsity via l1 projection */
  val sparsify = args.getOrElse("sparsify","true").toBoolean

  /* Relative error tolerance for the l1 projection */
  val error_tolerance = args.getOrElse("error_tolerance","0.01").toDouble

  /* Ball radius for the l1 projection */
  val ball_radius = args.getOrElse("ball_radius","10.0").toDouble

  /** The (id_field, instance_field) pipe of points to be clustered. */
  lazy val instances = SequenceFile(input, (id_field, instance_field)).read

  /**
   * Centers entering this iteration, as a single tuple with a 'centers field:
   *   - iteration 0: a List of num_starting_centers points chosen at random;
   *   - later iterations: what the previous iteration wrote. That is a List during the
   *     kmeans|| phase, and a Map[cluster id -> center] once kmeansIter has run.
   */
  var centers : Pipe = if (curr_iter == 0) {
    instances
      .map(instance_field -> 'rand){ i : StringKeyedVector => math.random }
      .groupAll{ _.sortWithTake[(Double, StringKeyedVector)](('rand, instance_field) -> 'list, num_starting_centers){ (a, b) => a._1 > b._1 } }
      .map('list -> 'centers){ l : List[(Double, StringKeyedVector)] => l.map(_._2) }
      .project('centers)
  } else {
    SequenceFile(out_dir + "iter_" + (curr_iter - 1) + "/centers", ('centers)).read
  }

  /** The oversampled kmeans|| centers (last init iteration's output), one point per tuple. */
  lazy val oversampled_cluster_centers = SequenceFile(out_dir + "iter_" + (init_iters - 1) + "/centers", ('centers)).read
    .flattenTo[StringKeyedVector]('centers -> 'center)
    .rename('center -> instance_field)

  val new_centers = if (curr_iter < init_iters) {
    /** Over sample (oversampling_factor * num_clusters) centers **/
    kmeansPlusPlusIter(instances, centers, take_per_round)
  } else if (curr_iter == init_iters) {
    /** Seed the recluster phase, then run the first recluster iteration **/
    val init_final_centers = kmeansPlusPlusReclusterInit(centers, num_clusters)
    kmeansIter(oversampled_cluster_centers, init_final_centers, num_clusters, instance_field, curr_iter)
  } else if (curr_iter <= max_finish_iters) {
    /** Recluster oversampled centers into num_clusters final centers **/
    kmeansIter(oversampled_cluster_centers, centers, num_clusters, instance_field, curr_iter)
  } else {
    /** Cluster the original dataset **/
    kmeansIter(instances, centers, num_clusters, instance_field, curr_iter)
  }

  /**
   * On the last iteration, write the centers (and optionally the assignments). Otherwise
   * persist the centers for the next iteration.
   */
  if (curr_iter == total_iter) {
    // total_iter >= init_iters, so on the final iteration new_centers always comes from
    // kmeansIter and 'centers is a Map[cluster id -> center]. Emit only the center vectors;
    // flattening the Map itself would emit (id, center) pairs.
    new_centers
      .flatMapTo('centers -> 'center){ m : Map[String, StringKeyedVector] => m.values }
      .write(SequenceFile(out_dir + "centers"))

    if (args.boolean("generate_assignments")) {
      instances
        .crossWithTiny(new_centers)
        .map((instance_field, 'centers) -> 'cluster_assignment){ i : (StringKeyedVector, Map[String, StringKeyedVector]) =>
          assignCluster(i._1, i._2)
        }
        .project(id_field, 'cluster_assignment)
        .write(SequenceFile(out_dir + "assignments"))
    }
  } else {
    new_centers.write(SequenceFile(out_dir + "iter_" + curr_iter + "/centers"))
  }

  /**
   * One kmeans|| oversampling round.
   *
   * Scores every point by its cosine distance to the nearest current center. The
   * take_per_round highest-scoring points are added to the center set; they are taken
   * deterministically rather than sampled in proportion to distance.
   *
   * @param instances      points to score, in instance_field
   * @param centers        single-tuple pipe whose 'centers field is a List[StringKeyedVector]
   * @param take_per_round number of new centers to add
   * @return single-tuple pipe whose 'centers field is the List of old plus new centers
   */
  def kmeansPlusPlusIter(instances : Pipe, centers : Pipe, take_per_round : Int) : Pipe = {
    // Rank by the raw distance. Dividing by the total distance would only rescale every score
    // by the same positive constant, which cannot change the top-k but costs an extra pass.
    val farthest = instances
      .crossWithTiny(centers)
      .map((instance_field, 'centers) -> 'closest_distance){ in : (StringKeyedVector, List[StringKeyedVector]) =>
        distanceToClosestCenter(in._1, in._2)
      }
      .groupAll{ _.sortWithTake[(Double, StringKeyedVector)](('closest_distance, instance_field) -> 'top_by_distance, take_per_round){ (a, b) => a._1 > b._1 } }
      .flattenTo[(Double, StringKeyedVector)]('top_by_distance -> ('distance, 'center))
      .project('center)

    /** Union the set of new centers and old centers **/
    (farthest ++ centers.flattenTo[StringKeyedVector]('centers -> 'center))
      .groupAll{ _.toList[StringKeyedVector]('center -> 'centers) }
  }

  /**
   * Seeds the recluster phase.
   *
   * From the single 'centers List (the oversampled centers), picks one at random. It then
   * grows that seed to num_clusters centers with kmeansPlusPlusInit.
   *
   * @return the pipe with 'centers replaced by a Map from cluster id ("0", "1", ...) to center
   */
  def kmeansPlusPlusReclusterInit(centers : Pipe, num_clusters : Int) : Pipe = {
    centers
      .map('centers -> 'centers){ data : List[StringKeyedVector] =>
        val seed = data(scala.util.Random.nextInt(data.size))
        // The seed is already one cluster, so only num_clusters - 1 more are needed.
        val init_centers = kmeansPlusPlusInit(num_clusters - 1, data, List(seed))
        init_centers.zipWithIndex.map { case (center, idx) => (idx.toString, center) }.toMap
      }
  }

  /**
   * One Lloyd iteration: assign every point to its nearest center, then recompute each center
   * as the mean of its points (optionally L1-projected when --sparsify is set).
   *
   * @param data      points to cluster, in point_sym
   * @param centers   single-tuple pipe whose 'centers field is a Map[cluster id -> center]
   * @param K         not used here; the cluster count comes from the centers Map
   * @param point_sym field holding the points in `data`
   * @param iter      iteration number, used only to name the debug output
   * @return single-tuple pipe whose 'centers field is the Map of recomputed centers. Clusters
   *         that received no points are absent from it.
   */
  def kmeansIter(data : Pipe, centers : Pipe, K : Int, point_sym : Symbol, iter : Int) : Pipe = {
    val cluster_assignments = data
      .crossWithTiny(centers)
      .map((point_sym, 'centers) -> 'assignment){ fields : (StringKeyedVector, Map[String, StringKeyedVector]) =>
        val (point, current) = fields
        assignCluster(point, current)
      }
      .project(point_sym, 'assignment)

    val grouped = cluster_assignments
      .groupBy('assignment){ _.size('denom).reduce[StringKeyedVector](point_sym){ (a, b) => a.add(b); a } }
      .map((point_sym, 'denom) -> 'cluster){ fields : (StringKeyedVector, Double) =>
        val (centroid, denom) = fields
        centroid.mul(1.0 / denom)
        // l1Projection returns a new vector (or its input unchanged), so its result must be
        // used. Discarding it would make --sparsify a no-op.
        if (sparsify) l1Projection(centroid, error_tolerance, ball_radius) else centroid
      }
      .project('assignment, 'cluster)

    // Debug output: the 100 highest-weighted keys of every recomputed center.
    grouped
      .map('cluster -> 'top){ i : StringKeyedVector =>
        i.getMap().toList.sortBy(_._2).reverse.take(100).map(_._1).mkString(" ")
      }
      .project('assignment, 'top)
      .write(SequenceFile(out_dir + "debug/iter_" + iter + "_top_terms"))

    grouped
      .groupAll{ _.toList[(String, StringKeyedVector)](('assignment, 'cluster) -> 'centers) }
      .map('centers -> 'centers){ l : List[(String, StringKeyedVector)] => l.toMap }
  }

  /**
   * Greedy kmeans++-style seeding (see http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf).
   *
   * Starting from `centers`, it repeatedly adds the point of `data` farthest (cosine distance)
   * from every center chosen so far. Unlike kmeans++, the farthest point is picked
   * deterministically instead of being sampled.
   *
   * @param iters   maximum number of centers to add
   * @param data    candidate points
   * @param centers non-empty list of initial centers
   * @return `centers` followed by up to `iters` new centers; fewer if `data` runs out
   */
  def kmeansPlusPlusInit(iters : Int, data : List[StringKeyedVector], centers : List[StringKeyedVector]) : List[StringKeyedVector] = {
    var chosen = centers
    // Each remaining candidate paired with its distance to the nearest chosen center.
    // The distance is updated incrementally as centers are added, instead of being recomputed.
    var remaining = data
      .filterNot(p => centers.contains(p))
      .map(p => (p, distanceToClosestCenter(p, centers)))
    var added = 0
    while (added < iters && remaining.nonEmpty) {
      val next = remaining.maxBy(_._2)._1
      chosen = chosen :+ next
      remaining = remaining.collect {
        case (p, d) if p != next => (p, math.min(d, computeDistance(p, next)))
      }
      added += 1
    }
    chosen
  }

  /**
   * Returns the cosine distance between a point and its closest center.
   */
  def distanceToClosestCenter(point : StringKeyedVector, centers : List[StringKeyedVector]) : Double = {
    centers.map(center => computeDistance(point, center)).min
  }

  /**
   * Computes the cosine distance (1 - cosine similarity) between a point and a center.
   */
  def computeDistance(point : StringKeyedVector, center : StringKeyedVector) : Double = {
    val dot_product = point.dot(center)
    val point_magnitude = point.LPNorm(2.0)
    val center_magnitude = center.LPNorm(2.0)
    1.0 - (dot_product/(point_magnitude*center_magnitude))
  }

  /**
   * Assigns a point to its nearest cluster center by cosine distance; returns the center's id.
   */
  def assignCluster(point : StringKeyedVector, centers : Map[String,StringKeyedVector]) : String = {
    val distances = centers.toList.map(i => (i._1, computeDistance(point, i._2)))
    distances.minBy{_._2}._1
  }

  /**
   * e-accurate projection of `center` onto the L1 ball of radius `lambda` (Sculley, "Web-Scale
   * K-Means Clustering").
   *
   * Bisection finds a threshold theta such that soft-thresholding every coordinate by theta
   * leaves an L1 norm within [lambda, lambda * (1 + e)].
   *
   * @param center vector to project; not modified
   * @param e      relative tolerance on the resulting L1 norm
   * @param lambda radius of the L1 ball
   * @return `center` itself when it is already within tolerance, otherwise a new, sparser vector
   */
  def l1Projection(center : StringKeyedVector, e : Double = 0.01, lambda : Double = 1.0) : StringKeyedVector = {
    val l1Norm = center.LPNorm(1.0)
    // Same relative tolerance as the bisection loop below.
    if (l1Norm <= lambda * (1 + e)) {
      center
    } else {
      // Soft-thresholding acts on |x|, so the largest useful theta is the largest |coordinate|.
      var upper = center.values().map(v => math.abs(v)).max
      var lower = 0.0
      var current = l1Norm
      var theta = 0.0
      var steps = 0
      // Bounded bisection: the cap keeps the loop finite even if floating point never
      // lands exactly inside the tolerance window.
      while ((current > lambda * (1 + e) || current < lambda) && steps < 100) {
        theta = (upper + lower) / 2.0
        current = center.values().map(v => math.max(0.0, math.abs(v) - theta)).sum
        if (current <= lambda) upper = theta else lower = theta
        steps += 1
      }
      val sparse_center = new StringKeyedVector()
      center.getMap()
        .map { case (key, value) => (key, math.signum(value) * math.max(0.0, math.abs(value) - theta)) }
        .filter(_._2 != 0.0)
        .foreach { case (key, value) => sparse_center.setCoordinate(key, value) }
      sparse_center
    }
  }

  /** Re-launches this job with curr_iter + 1 until iteration total_iter has run. */
  override def next : Option[Job] = {
    val new_args = args + (("curr_iter", Some((curr_iter + 1).toString)))
    if (curr_iter < total_iter) Some(clone(new_args)) else None
  }

  override def config = super.config ++
    Map("mapred.child.java.opts" -> "-Xmx%dG".format(xmx),
        "mapreduce.map.memory.mb" -> containerMemory.toString,
        "mapreduce.reduce.memory.mb" -> containerMemory.toString
    )
}
================================================
FILE: src/main/scala/com/etsy/scalding/jobs/conjecture/AdHocMulticlassClassifier.scala
================================================
package com.etsy.scalding.jobs.conjecture

import com.twitter.scalding.{Args, Job, Mode, SequenceFile, Tsv}
import com.etsy.conjecture.scalding.evaluate.MulticlassCrossValidator
import com.etsy.conjecture.scalding.train.MulticlassModelTrainer
import com.etsy.conjecture.data.{MulticlassLabeledInstance, StringKeyedVector}
import com.etsy.conjecture.model.UpdateableMulticlassLinearModel
import com.google.gson.Gson
import
cascading.tuple.Fields

/**
 * Ad hoc multiclass-classifier job.
 *
 * Reads labeled instances from a sequence file and trains a model with
 * MulticlassModelTrainer over the classes named in `--categories` (comma-separated). The
 * model is written to `<out_dir>/model` (sequence file) and `<out_dir>/model_json`
 * (Gson JSON). When `--folds` is positive it also runs cross-validation via
 * MulticlassCrossValidator, writing the first result pipe to `<out_dir>/xval` (TSV) and the
 * second to `<out_dir>/pred`.
 *
 * `--input`, `--out_dir` and `--categories` have no defaults.
 */
class AdHocMulticlassClassifier(args : Args) extends Job(args) {
  val input = args("input")
  val out_dir = args("out_dir")
  val folds = args.getOrElse("folds", "0").toInt
  val categories = args("categories").split(",")
  val xmx = args.getOrElse("xmx", "3").toInt
  // Container request: the JVM heap (xmx GB) in MB plus ~16% headroom.
  val containerMemory = (xmx * 1024 * 1.16).toInt

  // Field layout of the input sequence file, configurable on the command line.
  val data_field_names = args.getOrElse("data_fields", "instance").split(",")
  val data_fields = data_field_names.tail.foldLeft(new Fields(data_field_names.head)) {
    (fields, name) => fields.append(new Fields(name))
  }
  val instance_field = Symbol(args.getOrElse("instance_field", "instance"))

  // Only the instance column of the input sequence file is used.
  val instances = SequenceFile(input, data_fields).project(instance_field)

  val model_pipe = new MulticlassModelTrainer(args, categories).train(instances, instance_field, 'model)

  model_pipe
    .write(SequenceFile(out_dir + "/model"))
    .mapTo('model -> 'json) { model : UpdateableMulticlassLinearModel => new Gson().toJson(model) }
    .write(Tsv(out_dir + "/model_json"))

  if (folds > 0) {
    val (xval, pred) = new MulticlassCrossValidator(args, folds, categories)
      .crossValidateWithPredictions(instances, instance_field, 'pred)
    xval.write(Tsv(out_dir + "/xval"))
    pred.write(SequenceFile(out_dir + "/pred"))
  }

  override def config = {
    val containerMb = containerMemory.toString
    super.config ++ Map(
      "mapred.child.java.opts" -> "-Xmx%dG".format(xmx),
      "mapreduce.map.memory.mb" -> containerMb,
      "mapreduce.reduce.memory.mb" -> containerMb)
  }
}
================================================
FILE: src/main/scala/com/etsy/scalding/jobs/conjecture/AdHocPredictor.scala
================================================
package com.etsy.scalding.jobs.conjecture

import com.twitter.scalding.{Args, Job, Mode, SequenceFile, Tsv}
import com.etsy.conjecture.scalding.evaluate.BinaryEvaluator
import com.etsy.conjecture.data.{BinaryLabeledInstance, BinaryLabel}
import com.etsy.conjecture.model.UpdateableLinearModel
import com.google.gson.Gson
import cascading.tuple.Fields

/**
 * Scores labeled instances with a previously trained binary linear model.
 *
 * Inputs:
 *   - instances from `--input`, with field names given by `--data_fields` and the instance
 *     column given by `--instance_field`;
 *   - the model from the sequence file `--model`, column `--model_field`.
 *
 * Writes ('supporting_data, 'pred) to `<out_dir>/pred`, where 'pred is the value of the
 * predicted label. The output is sorted by 'pred descending unless `--skip_final_sort` is given.
 */
class AdHocPredictor(args : Args) extends Job(args) {
  val input = args.getOrElse("input", "specify_an_input_dir")
  val out_dir = args.getOrElse("out_dir", "adhoc_classifier")
  val model = args.getOrElse("model", "specify a model")
  // NOTE(review): read but not referenced anywhere in this job.
  val problemName = args.getOrElse("name", "demo_problem")
  val xmx = args.getOrElse("xmx", "3").toInt
  val skipFinalSort = args.boolean("skip_final_sort")
  // Container request: the JVM heap (xmx GB) in MB plus ~16% headroom.
  val containerMemory = (xmx * 1024 * 1.16).toInt

  // Let the user configure the field names on the command line.
  val data_field_names = args.getOrElse("data_fields", "instance").split(",")
  val data_fields = data_field_names.tail.foldLeft(new Fields(data_field_names.head)) {
    (x, y) => x.append(new Fields(y))
  }
  val model_field = new Fields(args.getOrElse("model_field", "model"))
  val instance_field = new Fields(args.getOrElse("instance_field", "instance"))

  val instances = SequenceFile(input, data_fields).read.project(instance_field)
  val model_pipe = SequenceFile(model, model_field).read

  // Only the prediction's value reaches the output. The model's per-instance explanation used
  // to be computed here too, but the final projection discarded it.
  val predictions = instances.crossWithTiny(model_pipe)
    .map((model_field, instance_field) -> 'pred) { x : (UpdateableLinearModel[BinaryLabel], BinaryLabeledInstance) =>
      x._1.predict(x._2.getVector).getValue()
    }
    .map(instance_field -> 'supporting_data) { x : BinaryLabeledInstance => x.getSupportingData() }
    .project('supporting_data, 'pred)

  val output = if (skipFinalSort) predictions else predictions.groupAll { _.sortBy('pred).reverse }
  output.write(SequenceFile(out_dir + "/pred"))

  override def config = super.config ++
    Map("mapred.child.java.opts" -> "-Xmx%dG".format(xmx),
        "mapreduce.map.memory.mb" -> containerMemory.toString,
        "mapreduce.reduce.memory.mb" -> containerMemory.toString
    )
}
================================================
FILE: src/main/scala/com/etsy/scalding/jobs/conjecture/NNMFTest.scala
================================================
package com.etsy.scalding.jobs.conjecture

import com.etsy.conjecture.scalding.NNMF
import com.twitter.scalding.{Args, Job, Tsv, SequenceFile}
import org.apache.commons.math3.linear.RealVector

/*
 * Iterative NNMF test job over the matrix in the TSV given by "A" (row, col, val per line).
 *
 * Each run performs one NNMF.updateGaussianWeighted step. It then re-launches itself until
 * "iters" runs have completed. State lives under "base_dir":
 *   H/<iter>   row factors   ('row, 'vec, 'bias)
 *   W/<iter>   column factors ('col, 'vec, 'bias)
 *   err/<iter> weighted mean squared error, over every (row, col) pair, of the factors
 *              this run started from
 * "alpha" is the extra weight given to non-zero entries.
 */
class NNMFTest(args : Args) extends Job(args) {
  val iter = args.getOrElse("iter", "0").toInt
  val iters = args.getOrElse("iters", "20").toInt
  val base_dir = args.getOrElse("base_dir", "nnmf_test")
  val A_path = args.getOrElse("A", "critics.tsv")
  val alpha = args.getOrElse("alpha", "0.0").toDouble

  val A = Tsv(A_path, ('row, 'col, 'val))
    .map('val -> 'val){ raw : String => raw.toDouble }

  // Factors entering this run: a fresh NNMF.initGaussian(A, 10) on the first run (presumably
  // rank 10), otherwise the previous run's output.
  val HW = if (iter == 0) {
    NNMF.initGaussian(A, 10)
  } else {
    val previous = iter - 1
    (SequenceFile(base_dir + "/H/" + previous, ('row, 'vec, 'bias)).read,
     SequenceFile(base_dir + "/W/" + previous, ('col, 'vec, 'bias)).read)
  }

  val HW_ = NNMF.updateGaussianWeighted(A, HW._1, HW._2, alpha)
  HW_._1.write(SequenceFile(base_dir + "/H/" + iter))
  HW_._2.write(SequenceFile(base_dir + "/W/" + iter))

  // Dense predictions for every (row, col) pair from the incoming factors, outer-joined with A.
  HW._1
    .crossWithSmaller(HW._2.rename('vec -> 'vec2).rename('bias -> 'bias2))
    .map(('vec, 'vec2, 'bias, 'bias2) -> 'pred){ x : (RealVector, RealVector, Double, Double) =>
      x._1.dotProduct(x._2) + x._3 + x._4
    }
    .project('row, 'col, 'pred)
    // NOTE(review): pairs absent from A carry a null 'val after the outer join; the Double
    // conversion below presumably reads it as 0.0.
    .joinWithSmaller(('row, 'col) -> ('row_, 'col_), A.rename(('row, 'col) -> ('row_, 'col_)), new cascading.pipe.joiner.OuterJoin())
    .mapTo(('val, 'pred) -> 'err){ x : (Double, Double) =>
      val (observed, predicted) = x
      val diff = observed - predicted
      val weight = if (observed == 0.0) 1.0 else 1.0 + alpha
      weight * diff * diff
    }
    .groupAll{ _.average('err) }
    .write(Tsv(base_dir + "/err/" + iter))

  /** Schedules the next run (iter + 1) until `iters` runs have completed. */
  override def next : Option[Job] = {
    val new_args = args + (("iter", Some((iter + 1).toString)))
    if (iter < iters - 1) Some(clone(new_args)) else None
  }
}
================================================
FILE: src/test/java/com/etsy/conjecture/data/LazyVectorTest.java
================================================
package com.etsy.conjecture.data;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import com.google.gson.Gson;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse;
import org.junit.Test;

public class LazyVectorTest {

    final double eps = 0.000001;

    // Lazy update shared by all tests: scales a parameter by 0.9^(b - a), i.e. a decay over
    // elapsed iterations.
    static final LazyVector.UpdateFunction uf = new LazyVector.UpdateFunction() {
        private static final long serialVersionUID = 1019666879466468375L;

        public double lazyUpdate(String k, double p, long a, long b) {
            return p * Math.pow(0.9, b - a);
        }
    };

    /**
     * Builds the LazyVector {foo: 1.0, bar: -2.0}. Along the way it exercises set, add
     * (including an add of 0.0, which creates no entry) and delete.
     */
    public LazyVector buildLV() {
        LazyVector vec = new LazyVector(uf);
        vec.setCoordinate("foo", 1.0);
        vec.addToCoordinate("bar", -2.0);
        vec.addToCoordinate("baz", 0.0);
        vec.setCoordinate("dave", 5.0);
        vec.deleteCoordinate("dave");
        return vec;
    }

    /**
     * Basic testing of coordinate getting and setting.
     */
    @Test
    public void testCoordinates() {
        LazyVector vec = buildLV();
        assertEquals(2, vec.size());
        assertEquals(1.0, vec.getCoordinate("foo"), eps);
        assertEquals(-2.0, vec.getCoordinate("bar"), eps);
        assertEquals(0.0, vec.getCoordinate("baz"), eps);
        assertEquals(0.0, vec.getCoordinate("dave"), eps);
        assertEquals(0.0, vec.getCoordinate("test"), eps);
    }

    /**
     * Basic testing of lazy updating.
*/ @Test public void testCoordinatesLazy() { LazyVector lv = buildLV(); lv.incrementIteration(); assertEquals(2, lv.size()); assertEquals(0.9, lv.getCoordinate("foo"), eps); assertEquals(-1.8, lv.getCoordinate("bar"), eps); assertEquals(0.0, lv.getCoordinate("baz"), eps); assertEquals(0.0, lv.getCoordinate("dave"), eps); assertEquals(0.0, lv.getCoordinate("test"), eps); lv.setCoordinate("bar", 2.0); lv.incrementIteration(); assertEquals(2, lv.size()); assertEquals(0.81, lv.getCoordinate("foo"), eps); assertEquals(1.8, lv.getCoordinate("bar"), eps); } /** * Test addScaled. */ @Test public void testAddScaledToSKV() { LazyVector lv = buildLV(); StringKeyedVector accum = new StringKeyedVector(); accum.addScaled(lv, 2.0); assertEquals(2, accum.size()); assertEquals(2.0, accum.getCoordinate("foo"), eps); assertEquals(-4.0, accum.getCoordinate("bar"), eps); lv.incrementIteration(); accum.addScaled(lv, -2.0); assertEquals(2, accum.size()); assertEquals(0.2, accum.getCoordinate("foo"), eps); assertEquals(-0.4, accum.getCoordinate("bar"), eps); } /** * Test addScaled. */ @Test public void testAddScaledToLV() { LazyVector lv = buildLV(); LazyVector accum = new LazyVector(uf); accum.setCoordinate("foo", 10.0); accum.incrementIteration(); accum.incrementIteration(); accum.incrementIteration(); // foo is now 7.29 accum.addScaled(lv, 2.0); assertEquals(2, accum.size()); assertEquals(9.29, accum.getCoordinate("foo"), eps); assertEquals(-4.0, accum.getCoordinate("bar"), eps); lv.incrementIteration(); // foo is now 0.9 accum.incrementIteration(); // foo is now 8.361 accum.addScaled(lv, -2.0); assertEquals(1, accum.size()); assertEquals(6.561, accum.getCoordinate("foo"), eps); } /** * Test addScaled. 
*/ @Test public void testAddScaledToSelf() { LazyVector lv = buildLV(); lv.incrementIteration(); lv.incrementIteration(); lv.addScaled(lv, 1.0); assertEquals(2, lv.size()); assertEquals(1.0 * 0.81 * 2, lv.getCoordinate("foo"), eps); assertEquals(-2.0 * 0.81 * 2, lv.getCoordinate("bar"), eps); } /** * Test addScaled. */ @Test public void testAddScaledSKVToLV() { LazyVector accum = new LazyVector(uf); StringKeyedVector skv = new StringKeyedVector(); skv.setCoordinate("foo", 1.0); skv.setCoordinate("bar", 5.0); accum.addScaled(skv, 2.0); assertEquals(2, accum.size()); assertEquals(2.0, accum.getCoordinate("foo"), eps); assertEquals(10.0, accum.getCoordinate("bar"), eps); accum.incrementIteration(); accum.incrementIteration(); // foo: 1.62, bar: 8.10 accum.addScaled(skv, -1.0); assertEquals(2, accum.size()); assertEquals(0.62, accum.getCoordinate("foo"), eps); assertEquals(3.10, accum.getCoordinate("bar"), eps); } /** * Test the dot product. */ @Test public void testDotProduct() { LazyVector skv = buildLV(); StringKeyedVector skv2 = new StringKeyedVector(skv); assertEquals(5.0, skv.dot(skv), eps); skv.incrementIteration(); assertEquals(5.0 * 0.81, skv.dot(skv), eps); skv2.addToCoordinate("baz", -10.0); assertEquals(5.0 * 0.9, skv.dot(skv2), eps); } /** * Test freezing the keys. */ @Test public void testFreezing() { LazyVector skv = buildLV(); skv.setFreezeKeySet(true); skv.addToCoordinate("fake", 1.0); assertEquals(2, skv.size()); skv.setCoordinate("fake2", 2.0); assertEquals(2, skv.size()); skv.setFreezeKeySet(false); skv.setCoordinate("fake2", 2.0); assertEquals(3, skv.size()); } /** * Test java serialization. */ @Test public void testJavaSerialization() throws Exception { LazyVector skv = buildLV(); skv.incrementIteration(); // Serialize to a byte array in ram. ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(bos); oos.writeObject(skv); oos.flush(); // Deserialize. 
ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); ObjectInputStream ois = new ObjectInputStream(bis); LazyVector des = (LazyVector)ois.readObject(); assertFalse(des.getFreezeKeySet()); assertEquals(2, des.size()); assertEquals(0.9, des.getCoordinate("foo"), eps); assertEquals(-1.8, des.getCoordinate("bar"), eps); des.incrementIteration(); assertEquals(2, des.size()); assertEquals(0.81, des.getCoordinate("foo"), eps); assertEquals(-1.62, des.getCoordinate("bar"), eps); } /** * Test kryo serialization. */ @Test public void testKryoSerialization() throws Exception { LazyVector skv = buildLV(); skv.incrementIteration(); // Serialize to a byte array in ram. ByteArrayOutputStream bos = new ByteArrayOutputStream(); Output ko = new Output(bos); Kryo kry = new Kryo(); kry.writeObject(ko, skv); ko.flush(); // Deserialize. ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); Input ki = new Input(bis); LazyVector des = (LazyVector)kry.readObject(ki, LazyVector.class); assertFalse(des.getFreezeKeySet()); assertEquals(2, des.size()); assertEquals(0.9, des.getCoordinate("foo"), eps); assertEquals(-1.8, des.getCoordinate("bar"), eps); des.incrementIteration(); assertEquals(2, des.size()); assertEquals(0.81, des.getCoordinate("foo"), eps); assertEquals(-1.62, des.getCoordinate("bar"), eps); } /** * Make sure Gson serializes this thing properly. 
*/
    @Test
    public void testGson() {
        Gson gson = new Gson();
        String json = gson.toJson(buildLV());
        String vector1 = "\"vector\":{\"foo\":1.0,\"bar\":-2.0}";
        String vector2 = "\"vector\":{\"bar\":-2.0,\"foo\":1.0}";
        String fks = "\"freezeKeySet\":false";
        assertTrue(json.contains(vector1) || json.contains(vector2));
        assertTrue(json.contains(fks));
        assertFalse(json.contains("iterations"));
    }
}

================================================
FILE: src/test/java/com/etsy/conjecture/data/StringKeyedVectorTest.java
================================================
package com.etsy.conjecture.data;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import com.google.gson.Gson;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse;

import org.junit.Test;

/**
 * Unit tests for {@link StringKeyedVector}: coordinate access, scaled
 * addition, dot products, key freezing, and Java / Kryo / Gson serialization.
 */
public class StringKeyedVectorTest {

    /** Absolute tolerance used for all floating-point comparisons. */
    final double eps = 0.000001;

    // Build an SKV in a way which exercises a bunch of different code.
    // The tests below expect the result to report size() == 2 with
    // foo = 1.0 and bar = -2.0; "baz" (added as 0.0) and "dave" (set, then
    // deleted) both read back as 0.0.
    public StringKeyedVector buildSKV() {
        StringKeyedVector skv = new StringKeyedVector();
        skv.setCoordinate("foo", 1.0);
        skv.addToCoordinate("bar", -2.0);
        skv.addToCoordinate("baz", 0.0);
        skv.setCoordinate("dave", 5.0);
        skv.deleteCoordinate("dave");
        return skv;
    }

    /**
     * Asserts that a deserialized vector has the same observable contents as
     * {@link #buildSKV()}: an unfrozen key set, two stored coordinates, and
     * the expected per-key values (0.0 for absent or deleted keys).
     *
     * @param des the vector produced by a serialization round trip
     */
    private void assertMatchesBuiltSKV(StringKeyedVector des) {
        assertFalse(des.getFreezeKeySet());
        assertEquals(2, des.size());
        assertEquals(1.0, des.getCoordinate("foo"), eps);
        assertEquals(-2.0, des.getCoordinate("bar"), eps);
        assertEquals(0.0, des.getCoordinate("baz"), eps);
        assertEquals(0.0, des.getCoordinate("dave"), eps);
        assertEquals(0.0, des.getCoordinate("test"), eps);
    }

    /**
     * Basic testing of coordinate getting and setting.
     */
    @Test
    public void testCoordinates() {
        StringKeyedVector skv = buildSKV();
        assertEquals(2, skv.size());
        assertEquals(1.0, skv.getCoordinate("foo"), eps);
        assertEquals(-2.0, skv.getCoordinate("bar"), eps);
        assertEquals(0.0, skv.getCoordinate("baz"), eps);
        assertEquals(0.0, skv.getCoordinate("dave"), eps);
        assertEquals(0.0, skv.getCoordinate("test"), eps);
    }

    /**
     * Test addScaled.
     */
    @Test
    public void testAddScaled() {
        StringKeyedVector skv = buildSKV();
        StringKeyedVector accum = new StringKeyedVector();
        skv.addScaled(accum, 1.0);
        accum.addScaled(skv, 2.0);
        assertEquals(2, accum.size());
        assertEquals(2.0, accum.getCoordinate("foo"), eps);
        assertEquals(-4.0, accum.getCoordinate("bar"), eps);
        // Subtracting the same amount back should leave no stored coordinates.
        accum.addScaled(skv, -2.0);
        assertEquals(0, accum.size());
    }

    /**
     * Test the dot product.
     */
    @Test
    public void testDotProduct() {
        StringKeyedVector skv = buildSKV();
        // 1.0^2 + (-2.0)^2
        assertEquals(5.0, skv.dot(skv), eps);
        StringKeyedVector skv2 = new StringKeyedVector(skv);
        skv2.addToCoordinate("baz", -10.0);
        // "baz" is 0.0 in skv, so the new coordinate contributes nothing.
        assertEquals(5.0, skv.dot(skv2), eps);
        // 1.0^2 + (-2.0)^2 + (-10.0)^2
        assertEquals(105.0, skv2.dot(skv2), eps);
    }

    /**
     * Test freezing the keys.
     */
    @Test
    public void testFreezing() {
        StringKeyedVector skv = buildSKV();
        skv.setFreezeKeySet(true);
        // With a frozen key set, writes to unseen keys must not add entries.
        skv.addToCoordinate("fake", 1.0);
        assertEquals(2, skv.size());
        skv.setCoordinate("fake2", 2.0);
        assertEquals(2, skv.size());
        skv.setFreezeKeySet(false);
        skv.setCoordinate("fake2", 2.0);
        assertEquals(3, skv.size());
    }

    /**
     * Test java serialization.
     */
    @Test
    public void testJavaSerialization() throws Exception {
        StringKeyedVector skv = buildSKV();

        // Serialize to a byte array in ram.
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        ObjectOutputStream oos = new ObjectOutputStream(bos);
        oos.writeObject(skv);
        // close() flushes buffered object data before releasing the stream.
        oos.close();

        // Deserialize.
        ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
        ObjectInputStream ois = new ObjectInputStream(bis);
        StringKeyedVector des = (StringKeyedVector)ois.readObject();
        ois.close();

        assertMatchesBuiltSKV(des);
    }

    /**
     * Test kryo serialization.
     */
    @Test
    public void testKryoSerialization() throws Exception {
        StringKeyedVector skv = buildSKV();

        // Serialize to a byte array in ram.
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        Output ko = new Output(bos);
        Kryo kry = new Kryo();
        kry.writeObject(ko, skv);
        ko.flush();

        // Deserialize. readObject is generic in the requested class, so no
        // cast is needed.
        ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
        Input ki = new Input(bis);
        StringKeyedVector des = kry.readObject(ki, StringKeyedVector.class);

        assertMatchesBuiltSKV(des);
    }

    /**
     * Make sure Gson serializes this thing properly.
     */
    @Test
    public void testGson() {
        Gson gson = new Gson();
        String json = gson.toJson(buildSKV());
        // Map iteration order is not fixed, so accept either key order.
        String vector1 = "\"vector\":{\"foo\":1.0,\"bar\":-2.0}";
        String vector2 = "\"vector\":{\"bar\":-2.0,\"foo\":1.0}";
        String fks = "\"freezeKeySet\":false";
        assertTrue(json.contains(vector1) || json.contains(vector2));
        assertTrue(json.contains(fks));
    }
}

================================================
FILE: src/test/java/com/etsy/conjecture/evaluation/TestReceiverOperatingCharacteristic.java
================================================
package com.etsy.conjecture.evaluation;

import static org.junit.Assert.assertEquals;

import org.junit.Test;

/**
 * Checks {@link ReceiverOperatingCharacteristic#binaryAUC()} against a
 * reference AUC computed with scikit-learn on a small fixed data set.
 */
public class TestReceiverOperatingCharacteristic {

    /** Binary ground-truth labels; index i pairs with {@code predictions[i]}. */
    static final double[] labels = { 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1,
            1, 0, 0, 0 };

    /** Classifier scores; the trailing ties exercise tied-score handling. */
    static final double[] predictions = { 0.80962, 0.48458, 0.65812, 0.16117,
            0.47375, 0.26587, 0.71517, 0.63866, 0.36296, 0.89639, 0.35936,
            0.22413, 0.36402, 0.41459, 0.83148, 0.23271, 0.23271, 0.23271 };

    // from scikit learn
    static final double AUC = 0.97402597402597402;

    /** Allowed absolute difference between computed and reference AUC. */
    static final double AUC_TOLERANCE = 0.0000001;

    @Test
    public void testAUC() {
        // Guard the fixture itself: a length mismatch would otherwise either
        // throw mid-loop or silently ignore trailing predictions.
        assertEquals(labels.length, predictions.length);

        ReceiverOperatingCharacteristic roc = new ReceiverOperatingCharacteristic();
        for (int i = 0; i < labels.length; i++) {
            roc.add(labels[i], predictions[i]);
        }
        assertEquals(AUC, roc.binaryAUC(), AUC_TOLERANCE);
    }
}

================================================
FILE: src/test/java/com/etsy/conjecture/model/UpdateableLinearModelTest.java
================================================
package com.etsy.conjecture.model;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import com.etsy.conjecture.data.StringKeyedVector;
import org.junit.Test;
import com.etsy.conjecture.data.BinaryLabeledInstance;

/**
 * Tests single-step updates, regularization, and input immutability for the
 * {@link UpdateableLinearModel} implementations (logistic regression, hinge,
 * least squares, MIRA) under several optimizers.
 */
public class UpdateableLinearModelTest {

    /** Absolute tolerance used for all floating-point comparisons. */
    final double eps = 0.000001;

    // JUnit 4 creates a new test-class instance for every @Test method, so
    // any state set on this optimizer by one test (e.g. via
    // set*RegularizationWeight) is not seen by the others.
    final SGDOptimizer optimizer = new ElasticNetOptimizer();

    /** A positive (label 1.0) instance with foo = 1.0 and bar = 2.0. */
    BinaryLabeledInstance getPositiveInstance() {
        BinaryLabeledInstance bli = new BinaryLabeledInstance(1.0);
        bli.setCoordinate("foo", 1.0);
        bli.setCoordinate("bar", 2.0);
        return bli;
    }

    /** A negative (label 0.0) instance with foo = 1.0 and baz = -1.0. */
    BinaryLabeledInstance getNegativeInstance() {
        BinaryLabeledInstance bli = new BinaryLabeledInstance(0.0);
        bli.setCoordinate("foo", 1.0);
        bli.setCoordinate("baz", -1.0);
        return bli;
    }

    @Test
    public void testLogisticRegressionBasic() {
        LogisticRegression slr = new LogisticRegression(optimizer);

        // perform one update and check parameter values.
        // The expected weights are eta * 0.5 * x for each feature x.
        double eta = slr.optimizer.getDecreasingLearningRate(slr.epoch);
        slr.update(getPositiveInstance());
        assertEquals(eta * 0.5, slr.getParam().getCoordinate("foo"), eps);
        assertEquals(eta * 1.0, slr.getParam().getCoordinate("bar"), eps);
        assertTrue(slr.predict(getPositiveInstance().getVector()).getValue() > 0.5);

        // perform a second update.
        slr.update(getNegativeInstance());
        assertTrue(slr.predict(getPositiveInstance().getVector()).getValue() > 0.5);
        assertTrue(slr.predict(getNegativeInstance().getVector()).getValue() < 0.5);
    }

    @Test
    public void testLogisticRegressionLaplaceRegularization() {
        SGDOptimizer laplaceOptimizer = optimizer.setLaplaceRegularizationWeight(0.1);
        LogisticRegression slr = new LogisticRegression(laplaceOptimizer);

        // perform one update and check parameter values.
        double eta = slr.optimizer.getDecreasingLearningRate(slr.epoch);
        slr.update(getPositiveInstance());
        assertEquals(eta * 0.5, slr.getParam().getCoordinate("foo"), eps);
        assertEquals(eta * 1.0, slr.getParam().getCoordinate("bar"), eps);

        // "bar" is absent from the negative instance, so only the L1 penalty
        // (eta2 * 0.1) moves it on this step.
        double eta2 = slr.optimizer.getDecreasingLearningRate(slr.epoch);
        slr.update(getNegativeInstance());
        assertEquals(eta * 1.0 - eta2 * 0.1, slr.getParam()
                .getCoordinate("bar"), eps);

        // update with a different example enough times to make bar -> 0.
        for (int i = 0; i < 10; i++) {
            slr.update(getNegativeInstance());
        }
        // Only foo and baz remain stored once bar has been driven to zero.
        assertEquals(2, slr.getParam().size());
        assertEquals(0.0, slr.getParam().getCoordinate("bar"), eps);
    }

    @Test
    public void testLogisticRegressionGaussianRegularization() {
        SGDOptimizer gaussianOptimizer = optimizer.setGaussianRegularizationWeight(0.2);
        LogisticRegression slr = new LogisticRegression(gaussianOptimizer);

        // perform one update and check parameter values.
        double eta = slr.optimizer.getDecreasingLearningRate(slr.epoch);
        slr.update(getPositiveInstance());
        assertEquals(eta * 0.5, slr.getParam().getCoordinate("foo"), eps);
        assertEquals(eta * 1.0, slr.getParam().getCoordinate("bar"), eps);

        // "bar" is absent from the negative instance, so the L2 penalty
        // shrinks it multiplicatively by (1 - eta2 * 0.2).
        double eta2 = slr.optimizer.getDecreasingLearningRate(slr.epoch);
        slr.update(getNegativeInstance());
        assertEquals(eta * 1.0 * (1.0 - eta2 * 0.2), slr.getParam()
                .getCoordinate("bar"), eps);
    }

    @Test
    public void testPerceptronBasic() {
        // A hinge model with threshold 0 is exercised here as a perceptron.
        Hinge p = new Hinge(optimizer).setThreshold(0.0);

        // perform one update and check parameter values.
        double eta = p.optimizer.getDecreasingLearningRate(p.epoch);
        p.update(getPositiveInstance());
        assertEquals(eta * 1.0, p.getParam().getCoordinate("foo"), eps);
        assertEquals(eta * 2.0, p.getParam().getCoordinate("bar"), eps);
        assertTrue(p.predict(getPositiveInstance().getVector()).getValue() > 0.5);

        // perform a second update.
        p.update(getNegativeInstance());
        assertTrue(p.predict(getPositiveInstance().getVector()).getValue() > 0.5);
        assertTrue(p.predict(getNegativeInstance().getVector()).getValue() < 0.5);
    }

    /**
     * Helper (not itself a JUnit test; it has no {@code @Test} annotation):
     * asserts that {@code model.update} leaves the training instance's
     * feature values unchanged.
     *
     * @param model the model under test; updated once with a positive instance
     */
    public void testInstanceNotModified(UpdateableLinearModel model) {
        BinaryLabeledInstance instance = getPositiveInstance();
        StringKeyedVector instanceCopy = instance.getVector().copy();
        model.update(instance);
        // assertEquals takes (expected, actual): the pre-update snapshot is
        // the expected value, so failures report the mutated value correctly.
        assertEquals(instanceCopy.getCoordinate("foo"),
                instance.getVector().getCoordinate("foo"), 0.0);
        assertEquals(instanceCopy.getCoordinate("bar"),
                instance.getVector().getCoordinate("bar"), 0.0);
    }

    @Test
    public void testInstanceNotModifiedByOptimizer() {
        ElasticNetOptimizer eOptimizer = new ElasticNetOptimizer();
        LogisticRegression eModel = new LogisticRegression(eOptimizer);
        testInstanceNotModified(eModel);

        FTRLOptimizer ftrlOptimizer = new FTRLOptimizer();
        LogisticRegression fModel = new LogisticRegression(ftrlOptimizer);
        testInstanceNotModified(fModel);

        AdagradOptimizer adagradOptimizer = new AdagradOptimizer();
        LogisticRegression aModel = new LogisticRegression(adagradOptimizer);
        testInstanceNotModified(aModel);

        MIRA mModel = new MIRA();
        testInstanceNotModified(mModel);
    }

    @Test
    public void testInstanceNotModifiedByModel() {
        LogisticRegression lrModel = new LogisticRegression(optimizer);
        testInstanceNotModified(lrModel);

        LeastSquaresRegressionModel lsModel = new LeastSquaresRegressionModel(optimizer);
        testInstanceNotModified(lsModel);

        Hinge hModel = new Hinge(optimizer);
        testInstanceNotModified(hModel);
    }
}