Repository: etsy/Conjecture
Branch: master
Commit: a32d61966b12
Files: 117
Total size: 449.7 KB
Directory structure:
gitextract__actrptd/
├── .gitignore
├── .travis.yml
├── LICENSE.md
├── README.md
├── bin/
│ ├── demo.sh
│ ├── model_diff.py
│ ├── model_param.py
│ └── prediction_inspection.py
├── build.sbt
├── clients/
│ └── phplib/
│ └── Conjecture/
│ ├── BinaryClassifier.php
│ ├── Config.php
│ ├── ConjectureException.php
│ ├── Finder.php
│ ├── Instance.php
│ ├── MulticlassClassifier.php
│ ├── MulticlassLogisticRegressionClassifier.php
│ ├── MulticlassOneVsAllClassifier.php
│ ├── Text.php
│ ├── TextSequence.php
│ └── Vector.php
├── data/
│ └── iris.tsv
├── project/
│ ├── build.properties
│ └── plugins.sbt
├── sbt
└── src/
├── main/
│ ├── java/
│ │ └── com/
│ │ └── etsy/
│ │ └── conjecture/
│ │ ├── GenericPair.java
│ │ ├── PrimitivePair.java
│ │ ├── Utilities.java
│ │ ├── data/
│ │ │ ├── AbstractInstance.java
│ │ │ ├── BinaryLabel.java
│ │ │ ├── BinaryLabeledInstance.java
│ │ │ ├── ByteArrayDoubleHashMap.java
│ │ │ ├── ClusterLabel.java
│ │ │ ├── ClusterPrediction.java
│ │ │ ├── Instance.java
│ │ │ ├── InstanceFactory.java
│ │ │ ├── InstanceInterface.java
│ │ │ ├── Label.java
│ │ │ ├── LabeledInstance.java
│ │ │ ├── LazyVector.java
│ │ │ ├── MulticlassLabel.java
│ │ │ ├── MulticlassLabeledInstance.java
│ │ │ ├── MulticlassPrediction.java
│ │ │ ├── RealValueLabeledInstance.java
│ │ │ ├── RealValuedLabel.java
│ │ │ ├── Recommendation.java
│ │ │ └── StringKeyedVector.java
│ │ ├── evaluation/
│ │ │ ├── BinaryModelEvaluation.java
│ │ │ ├── ConfusionMatrix.java
│ │ │ ├── EvaluationAggregator.java
│ │ │ ├── ModelEvaluation.java
│ │ │ ├── MulticlassConfusionMatrix.java
│ │ │ ├── MulticlassModelEvaluation.java
│ │ │ ├── MulticlassReceiverOperatingCharacteristic.java
│ │ │ ├── ReceiverOperatingCharacteristic.java
│ │ │ └── RegressionModelEvaluation.java
│ │ ├── model/
│ │ │ ├── AdagradOptimizer.java
│ │ │ ├── ClusteringModel.java
│ │ │ ├── ControlOptimizer.java
│ │ │ ├── Decomposable.java
│ │ │ ├── ElasticNetOptimizer.java
│ │ │ ├── FTRLOptimizer.java
│ │ │ ├── Hinge.java
│ │ │ ├── KMeans.java
│ │ │ ├── LeastSquaresRegressionModel.java
│ │ │ ├── LogisticRegression.java
│ │ │ ├── MIRA.java
│ │ │ ├── MIRAOptimizer.java
│ │ │ ├── Model.java
│ │ │ ├── PassiveAggressiveOptimizer.java
│ │ │ ├── SGDOptimizer.java
│ │ │ ├── UpdateableLinearModel.java
│ │ │ ├── UpdateableModel.java
│ │ │ └── UpdateableMulticlassLinearModel.java
│ │ └── topics/
│ │ └── lda/
│ │ ├── LDADenseTopics.java
│ │ ├── LDADict.java
│ │ ├── LDADoc.java
│ │ ├── LDAPartialSparseTopics.java
│ │ ├── LDAPartialTopics.java
│ │ ├── LDARandomTopics.java
│ │ ├── LDASparseTopics.java
│ │ ├── LDATopics.java
│ │ └── LDAUtils.java
│ └── scala/
│ └── com/
│ └── etsy/
│ ├── conjecture/
│ │ ├── VWReader.scala
│ │ ├── demo/
│ │ │ ├── DemoLinearHyperparameterSearch.scala
│ │ │ ├── IrisDataToMulticlassLabeledInstances.scala
│ │ │ └── LearnMulticlassClassifier.scala
│ │ ├── scalding/
│ │ │ ├── ALSJob.scala
│ │ │ ├── FastKNN.scala
│ │ │ ├── LSH.scala
│ │ │ ├── NNMF.scala
│ │ │ ├── SVD.scala
│ │ │ ├── evaluate/
│ │ │ │ ├── GenericCrossValidator.scala
│ │ │ │ └── GenericEvaluator.scala
│ │ │ ├── factorize/
│ │ │ │ └── FactorizationTools.scala
│ │ │ ├── train/
│ │ │ │ ├── AbstractModelTrainer.scala
│ │ │ │ ├── BinaryModelTrainer.scala
│ │ │ │ ├── ClusteringModelTrainer.scala
│ │ │ │ ├── LargeModelTrainer.scala
│ │ │ │ ├── ModelTrainerStrategy.scala
│ │ │ │ ├── MulticlassModelTrainer.scala
│ │ │ │ ├── RegressionModelTrainer.scala
│ │ │ │ └── SmallModelTrainer.scala
│ │ │ └── util/
│ │ │ ├── BaseGridSearcher.scala
│ │ │ ├── DynamicOptions.scala
│ │ │ └── HyperparameterSearcher.scala
│ │ └── text/
│ │ ├── FeatureHelper.scala
│ │ ├── Text.scala
│ │ └── TextSequence.scala
│ └── scalding/
│ └── jobs/
│ └── conjecture/
│ ├── AdHocClassifier.scala
│ ├── AdHocClusterer.scala
│ ├── AdHocMulticlassClassifier.scala
│ ├── AdHocPredictor.scala
│ └── NNMFTest.scala
└── test/
└── java/
└── com/
└── etsy/
└── conjecture/
├── data/
│ ├── LazyVectorTest.java
│ └── StringKeyedVectorTest.java
├── evaluation/
│ └── TestReceiverOperatingCharacteristic.java
└── model/
└── UpdateableLinearModelTest.java
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.class
*.log
*.swp
*.swo
# sbt specific
dist/*
target/
lib_managed/
src_managed/
project/boot/
project/plugins/project/
# Scala-IDE specific
.scala_dependencies
#java
*.class
# Package Files #
*.jar
*.war
*.ear
*~
*\#
.history
.idea
================================================
FILE: .travis.yml
================================================
sudo: false
language: scala
script:
- sbt +test
================================================
FILE: LICENSE.md
================================================
The MIT License
===============
Copyright (c) 2009 Anton Grigoryev
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
================================================
FILE: README.md
================================================
# Conjecture [![Build Status](https://travis-ci.org/etsy/Conjecture.svg?branch=master)](https://travis-ci.org/etsy/Conjecture)
Conjecture is a framework for building machine learning models in Hadoop using the Scalding DSL.
The goal of this project is to enable the development of statistical models as viable components
in a wide range of product settings. Applications include classification and categorization,
recommender systems, ranking, filtering, and regression (predicting real-valued numbers).
Conjecture has been designed with a primary emphasis on flexibility and can handle a wide variety of inputs.
Integration with Hadoop and Scalding enables seamless handling of extremely large data volumes,
and integration with established ETL processes. Predicted labels can either be consumed directly
by the web stack using the dataset loader, or models can be deployed and consumed by live web code.
Currently, binary classification (assigning one of two possible labels to input data points)
is the most mature component of the Conjecture package.
# Tutorial
There are a few stages involved in training a machine learning model using Conjecture.
## Create Training Data
We represent the training data as "feature vectors" which are just mappings of feature names to real values.
In this case we represent them as a java map of strings to doubles
(although we have a class StringKeyedVector which provides convenience methods for feature vector construction).
We also need the true label of each instance, which we represent as 0 and 1
(the mapping of these binary labels to e.g., "male" and "female" is up to the user).
We construct BinaryLabeledInstances, which are just wrappers for a feature vector and a label.
```scala
val bl = new BinaryLabeledInstance(0.0)
bl.addTerm("bias", 1.0)
bl.addTerm("some_feature", 0.5)
```
## Training a Classifier
Classifiers are essentially trained by presenting the labeled instances to them. There are several kinds
of linear classifiers we implement, among them:
* Logistic regression,
* Perceptron,
* MIRA (a large margin perceptron model),
* Passive aggressive.
These models all have several options, such as learning rate, regularization parameters and so on. We supply
reasonable defaults for these parameters although they can be changed readily. To train a linear model
simply call the update function with the labeled instance:
```scala
val p = new LogisticRegression()
p.update(bl)
```
To make this procedure tractable for large datasets, we provide Scalding wrappers for the training.
These operate by training several small models on mappers, then aggregating them into a final complete model
on the reducers. This wrapper is called like so:
```scala
new BinaryModelTrainer(args)
  .train(instances, 'instance, 'model)
  .write(SequenceFile("model"))
  .map('model -> 'model){ x : UpdateableBinaryModel => new com.google.gson.Gson().toJson(x) }
  .write(Tsv("model_json"))
```
This code segment will train a model using a pipe called instances which has a field called instance which contains
the BinaryLabeledInstance objects. It produces a pipe with a single field containing the completed model, which can
then be written to disk.
This class uses the command line args object from scalding, in order to let you set some options on the command line.
Some useful options are:
| Argument | Possible values | Default | Meaning |
|-------------------------------------|-----------------------------------------------|--------------------|--------------------------------------------------|
| --model | mira, logistic_regression, passive_aggressive | passive_aggressive | The type of model to use. |
| --iters | 1, 2, 3... | 1 | The number of iterations of training to perform. |
| --zero_class_prob, --one_class_prob | [0, 1] | 1 | |
To see all the command line options, see the BinaryModelTrainer class.
## Evaluating a Classifier
It is important to get a sense of the performance you can expect out of your classifier on unseen data.
To do this, we recommend using cross validation.
In essence, your input set of instances is split up into testing and training portions (multiple different ways),
then a classifier is trained on each training portion, and evaluated (against the true labels which are present)
using the testing portion.
This is all wrapped up in a class called BinaryCrossValidator, it is used like so:
```scala
new BinaryCrossValidator(args, 5)
  .crossValidate(instances, 'instance)
  .write(Tsv("model_xval"))
```
This class also takes the command line arguments, which it passes to a model trainer for each fold.
This allows the specification of options to the cross validated models on the command line.
The output contains statistics about the performance of the model as well as the confusion matrices
for each fold.
A script is included which cross validates a logistic regression model on the iris dataset.
================================================
FILE: bin/demo.sh
================================================
#!/bin/bash
# Demo: build the Conjecture assembly jar, turn the bundled iris data into
# multiclass labeled instances, then train and cross-validate a classifier.
# Abort on the first failing step so later steps never run against a stale
# or missing jar.
set -e
# - make monolithic conjecture jar.
sbt clean assembly
# sbt-assembly writes the jar under the Scala cross-target directory
# (target/scala-<binary version>/) by default; target/ is also accepted in case
# assemblyOutputPath is customised. Resolve exactly one jar so -cp never gets
# an unexpanded glob or several paths.
JAR=$(ls target/scala-*/conjecture-assembly-*.jar target/conjecture-assembly-*.jar 2>/dev/null | head -n 1)
if [ -z "$JAR" ]; then
  echo "could not find the conjecture assembly jar under target/" >&2
  exit 1
fi
# - make the instances.
java -cp "$JAR" com.twitter.scalding.Tool com.etsy.conjecture.demo.IrisDataToMulticlassLabeledInstances --input_file data/iris.tsv --output_file iris_model/instances --local
# - construct the classifier.
java -cp "$JAR" com.twitter.scalding.Tool com.etsy.conjecture.demo.LearnMulticlassClassifier --input iris_model/instances --output iris_model --class_names Iris-versicolor,Iris-virginica,Iris-setosa --iters 5 --folds 3 --local
================================================
FILE: bin/model_diff.py
================================================
import json
import sys
import math
if __name__ == '__main__':
    # Compare the 'param' -> 'vector' weight maps of two serialized models and
    # print every feature whose weight differs by more than 0.01, largest
    # absolute difference first.
    if len(sys.argv) != 3:
        sys.exit("Usage: python " + sys.argv[0] + " [model file] [model file]")
    # `with` closes both files; the original left the handles open.
    with open(sys.argv[1]) as fa:
        a = json.load(fa)['param']['vector']
    with open(sys.argv[2]) as fb:
        b = json.load(fb)['param']['vector']
    features = set(a.keys()) | set(b.keys())
    diff = []
    for f in features:
        # A feature missing from one model counts as weight 0.0 for the diff;
        # its raw value is still reported as None.
        dv = a.get(f, 0.0) - b.get(f, 0.0)
        if math.fabs(dv) > 0.01:
            diff.append((f, dv, a.get(f), b.get(f)))
    diff.sort(key=lambda tup: -math.fabs(tup[1]))
    for t in diff:
        # print(...) with a single argument behaves the same on Python 2 and 3;
        # the old print statement was a SyntaxError under Python 3.
        print(t[0] + "\t" + str(t[2]) + "\t" + str(t[3]) + "\t(" + str(t[1]) + ")")
================================================
FILE: bin/model_param.py
================================================
import json
import sys
import math
if __name__ == '__main__':
    # Print every parameter of a serialized model, largest absolute weight first.
    if len(sys.argv) != 2:
        sys.exit("Usage: python " + sys.argv[0] + " [model file]")
    with open(sys.argv[1]) as fh:
        # sorted() works on the items() view returned by Python 3 as well as on
        # Python 2's list; list.sort() on a dict view raised AttributeError.
        vec = sorted(json.load(fh)['param']['vector'].items(),
                     key=lambda tup: -math.fabs(tup[1]))
    for v in vec:
        print(v[0] + "\t" + str(v[1]))
================================================
FILE: bin/prediction_inspection.py
================================================
import json
import sys
from optparse import OptionParser
from math import floor
# 32-entry colour ramp used to shade each prediction by its error: index 0 is
# red (FF0000), the middle passes through yellow (FFFF00) and the last entry is
# green (10FF00). `bins` is the number of error buckets.
colors = ["FF0000", "FF1000", "FF2000", "FF3000", "FF4000", "FF5000", "FF6000",
"FF7000", "FF8000", "FF9000", "FFA000", "FFB000", "FFC000", "FFD000",
"FFE000", "FFF000", "FFFF00", "F0FF00", "E0FF00", "D0FF00", "C0FF00",
"B0FF00", "A0FF00", "90FF00", "80FF00", "70FF00", "60FF00", "50FF00",
"40FF00", "30FF00", "20FF00", "10FF00"]
bins = len(colors)
# Command-line interface: where to read predictions from, where to write the
# HTML page, an optional label filter and a cap on the number of rows shown.
usage_text = """builds a simple web page providing introspection on predictions made by conjecture models.
Depends on the supporting data provided in the instance itself, currently only supporting binary
classification problems
Usage: %prog [options]
"""
parser = OptionParser(usage=usage_text)
parser.add_option('-o', '--out', action='store', dest='out', default=False,
help="[optional] destination of the generated html. Defaults to standard out")
parser.add_option('-f', '--file', action='store', dest='file', default=False,
help="[optional] file storing input predictions and instances. Defaults to standard in")
parser.add_option('-l', '--label', action='store', dest='label', default=False,
help="[optional] only keep examples with this label")
parser.add_option('-L', '--limit', action='store', dest='limit', default=1000,
help="maximum number of prediction examples to display. Default: 1000")
options, args = parser.parse_args()
# Fall back to the standard streams when no paths are given.
if options.out:
    output = open(options.out, 'w')
else:
    output = sys.stdout
if options.file:
    input = open(options.file, 'r')
else:
    input = sys.stdin
# --limit arrives as a string when given on the command line, as an int otherwise.
limit = int(options.limit)
# Render one coloured row per prediction.
# NOTE(review): the HTML markup inside the write() literals in this block looks
# stripped in this copy (empty strings, literals split across lines) and
# `color` is computed but never used below; restore the markup from version
# control before running.
output.write("")
ct = 0
for line in input:
# Tab-separated input: column 0 is the instance JSON, column 2 the predicted
# score (column 1 is not read here).
parts = line.strip().split("\t")
content = json.loads(parts[0])
label = int(content['label']['value'])
pred = float(parts[2])
if (options.label and str(label) != options.label):
continue
# Absolute error, capped at 1.0, picks a colour bin: small error -> high index
# (green end), large error -> low index (red end).
# NOTE(review): error == 1.0 yields bin == -1, which Python wraps to colors[-1]
# (the greenest colour) for the worst predictions; clamp with max(0, ...).
error = min(1.0, abs(pred-label))
bin = bins - int(floor(error*bins)) - 1
color = "#" + colors[bin]
out = ""
# supporting_data is itself a JSON-encoded string nested in the instance JSON.
support = json.loads(content['supporting_data'])
for key in support.keys():
out = out + "" + key + " " + support[key] + " "
# Only rows with under 10000 characters of supporting data are written, and at
# most `limit` rows overall.
if (len(out) < 10000 and ct < limit):
try:
output.write("
");
output.write("%d (%f) " %( label, pred))
output.write(out)
output.write("
")
ct = ct + 1
# NOTE(review): a bare except silently drops any row whose write fails
# (presumably encoding errors); consider logging instead.
except:
pass
if (ct >= limit):
break
output.write("");
output.flush()
output.close()
================================================
FILE: build.sbt
================================================
// Core sbt build definition for the Conjecture library.
import sbt._
name := "conjecture"
version := "0.3.1-SNAPSHOT"
organization := "com.etsy"
scalaVersion := "2.11.11"
// `sbt +<task>` (as used in .travis.yml) runs the task for each version listed here.
crossScalaVersions := Seq("2.11.11", "2.12.4")
scalacOptions ++= Seq("-unchecked", "-deprecation")
//Because some of our (legal!) java code confuses scaladoc, we must skip it for 2.12
//See: https://github.com/scala/bug/issues/10723
// Only add the flag for 2.12; the previous `+=` form appended an empty-string
// option ("") to the scaladoc arguments for every other Scala version.
scalacOptions in (Compile, doc) ++= {if(scalaBinaryVersion.value == "2.12") Seq("-no-java-comments") else Nil}
// Java sources target 1.7 and compile before the Scala sources that use them.
javacOptions ++= Seq("-Xlint:none", "-source", "1.7", "-target", "1.7")
compileOrder := CompileOrder.JavaThenScala
// NOTE(review): plain-http resolver; newer sbt/coursier releases reject
// insecure repositories, so confirm the sbt version in project/build.properties.
resolvers ++= {
Seq(
"Concurrent Maven Repo" at "http://conjars.org/repo"
)
}
// Library dependencies. Guava is pinned to 13.0.1 below and excluded from
// cascading-local and the Hadoop artifacts, presumably so the pinned version
// is the one on the classpath (confirm with the dependency tree).
// The scala-library exclusions on the Twitter/spray artifacts keep them from
// pulling in a different Scala runtime than scalaVersion.
libraryDependencies ++= Seq(
"cascading" % "cascading-core" % "2.6.1",
"cascading" % "cascading-local" % "2.6.1" exclude("com.google.guava", "guava"),
"cascading" % "cascading-hadoop" % "2.6.1",
"com.google.code.gson" % "gson" % "2.2.2",
"com.twitter" %% "algebird-core" % "0.13.0" excludeAll ExclusionRule(organization="org.scala-lang", name="scala-library"),
"com.twitter" %% "scalding-core" % "0.17.4" excludeAll ExclusionRule(organization="org.scala-lang", name="scala-library"),
"commons-lang" % "commons-lang" % "2.4",
"com.joestelmach" % "natty" % "0.7",
"io.spray" %% "spray-json" % "1.3.2" excludeAll ExclusionRule(organization="org.scala-lang", name="scala-library"),
"com.google.guava" % "guava" % "13.0.1",
"org.apache.commons" % "commons-math3" % "3.2",
"org.apache.hadoop" % "hadoop-common" % "2.5.0" excludeAll(
ExclusionRule(organization="commons-daemon", name="commons-daemon"),
ExclusionRule(organization="com.google.guava", name="guava")
),
"org.apache.hadoop" % "hadoop-hdfs" % "2.5.0" excludeAll(
ExclusionRule(organization="commons-daemon", name="commons-daemon"),
ExclusionRule(organization="com.google.guava", name="guava")
),
// scala-reflect follows whichever scalaVersion is being built, so cross builds stay consistent.
"org.scala-lang" % "scala-reflect" % scalaVersion.value,
"net.sf.trove4j" % "trove4j" % "3.0.3",
// Test-only: lets sbt run the JUnit tests under src/test/java.
"com.novocode" % "junit-interface" % "0.10" % "test"
)
// Tests run sequentially, and no separate test artifact is published.
parallelExecution in Test := false
publishArtifact in Test := false
// Settings from the sbt-sonatype plugin (presumably declared in project/plugins.sbt).
xerial.sbt.Sonatype.sonatypeSettings
// With -Drelease set, keep the previously defined publishTo (presumably the
// Sonatype one); otherwise publish to Etsy's internal repository, using the
// snapshots path for -SNAPSHOT versions and the internal path for releases.
publishTo := {
if (System.getProperty("release") != null) {
publishTo.value
} else {
val v = version.value
val archivaURL = "http://ivy.etsycorp.com/repository"
if (v.trim.endsWith("SNAPSHOT")) {
Some("publish-snapshots" at (archivaURL + "/snapshots"))
} else {
Some("publish-releases" at (archivaURL + "/internal"))
}
}
}
publishMavenStyle := true
overridePublishBothSettings
pomIncludeRepository := { x => false }
// NOTE(review): the XML tags of this pomExtra literal are missing in this copy
// (only the text nodes remain), so it will not compile as shown; presumably it
// held <url>, <licenses>, <scm> and <developers> elements — restore from VCS.
pomExtra := https://github.com/etsy/Conjecture
MIT License
http://opensource.org/licenses/MIT
repo
git@github.com:etsy/Conjecture.git
scm:git:git@github.com:etsy/Conjecture.git
jattenberg
Josh Attenberg
github.com/jattenberg
rjhall
Rob Hall
github.com/rjhall
// NOTE(review): repeats the pomIncludeRepository setting above with the same
// value; the later assignment wins, so one of the two can be removed.
pomIncludeRepository := { _ => false }
// Uncomment if you don't want to run all the tests before building assembly
// test in assembly := {}
// Janino includes a broken signature, and is not needed:
// The filter returns the classpath jars whose file names are in `excludes`;
// those jars are left out of the fat jar.
// NOTE(review): `<<=` was removed in sbt 1.x; rewrite with `:=` and `.value`
// when upgrading sbt.
assemblyExcludedJars in assembly <<= (fullClasspath in assembly) map { cp =>
val excludes = Set("jsp-api-2.1-6.1.14.jar", "jsp-2.1-6.1.14.jar",
"jasper-compiler-5.5.12.jar", "janino-2.5.16.jar")
cp filter { jar => excludes(jar.data.getName)}
}
// Some of these files have duplicates, let's ignore:
// Duplicate entries: the last copy wins for .class, .html and any path that
// contains "servlet"; project.clj files are concatenated; anything else uses
// the previously configured strategy (`old`).
assemblyMergeStrategy in assembly <<= (mergeStrategy in assembly) { (old) =>
{
case s if s.endsWith(".class") => MergeStrategy.last
case s if s.endsWith("project.clj") => MergeStrategy.concat
case s if s.endsWith(".html") => MergeStrategy.last
case s if s.contains("servlet") => MergeStrategy.last
case x => old(x)
}
}
================================================
FILE: clients/phplib/Conjecture/BinaryClassifier.php
================================================
param = $param_vec;
}
/**
 * Returns the raw (pre-sigmoid) score of an instance: the dot product of this
 * classifier's parameter vector with the instance features.
 *
 * @param mixed $instance_vec instance features (presumably a Conjecture_Vector — confirm)
 * @return float
 */
public function dot($instance_vec) {
    $weights = $this->param;
    return $weights->dot($instance_vec);
}
/**
 * Returns the probability of the positive class for an instance: the
 * logistic (sigmoid) function applied to dot().
 *
 * The two branches are algebraically identical. Choosing by sign keeps exp()
 * from overflowing to INF; the previous exp(x) / (1 + exp(x)) evaluated to
 * INF / INF = NAN for large positive scores.
 *
 * @param mixed $instance_vec instance features passed through to dot()
 * @return float probability in [0, 1]
 */
public function predict($instance_vec) {
    $dot = $this->dot($instance_vec);
    if ($dot >= 0) {
        return 1.0 / (1.0 + exp(-$dot));
    }
    $exd = exp($dot);
    return $exd / (1.0 + $exd);
}
/**
 * Exposes the underlying parameter map by delegating to the parameter vector.
 *
 * @return mixed whatever the parameter vector's getParams() returns
 */
public function getParams() {
    $weights = $this->param;
    return $weights->getParams();
}
/**
 * Summarises the features that most influence the prediction for an instance.
 *
 * Only features present in both the parameter vector and the instance are
 * considered. They are ranked by absolute weight, and the top $n are rendered
 * as "feature(weight)" with the weight rounded to two decimals.
 *
 * @param mixed $instance_vec instance features; must expose getParams()
 * @param int $n maximum number of features to include
 * @return string comma-separated "feature(weight)" entries
 */
public function explain($instance_vec, $n = 10) {
    $keys = array_intersect_key($this->param->getParams(), $instance_vec->getParams());
    $keys = array_map('abs', $keys);
    arsort($keys);
    // preserve_keys = true: PHP stores numeric feature names ("0", "1", ...)
    // as integer keys, and array_slice would otherwise renumber them and pair
    // the wrong names with the weights. A length beyond the array size just
    // returns everything, so no min() is needed.
    $res = array_slice($keys, 0, $n, true);
    foreach ($res as $k => $v) {
        $res[$k] = "$k(" . round($this->param->getParam($k), 2) . ")";
    }
    return implode(", ", $res);
}
}
================================================
FILE: clients/phplib/Conjecture/Config.php
================================================
config = $config;
}
/**
 * Loads a binary classifier from a model JSON file on the local filesystem.
 *
 * @param string $local_file_path path to the model JSON file
 * @return Conjecture_BinaryClassifier classifier over the model's param->vector
 * @throws Conjecture_ConjectureException if the file is too big or cannot be read
 */
public function getLocalModel($local_file_path) {
    $decoded = json_decode($this->parseFile($local_file_path));
    $weights = new Conjecture_Vector($decoded->param->vector);
    return new Conjecture_BinaryClassifier($weights);
}
/**
 * Reads the raw model JSON at a given file path.
 *
 * NOTE(review): stripslashes() also removes JSON escape backslashes (\" and
 * \\); presumably model files carry an extra escaping layer — confirm.
 *
 * @param string $fp path to the model file
 * @return string file contents with backslashes stripped
 * @throws Conjecture_ConjectureException if the file is missing, unreadable,
 *         empty, or larger than the configured maximum size
 */
private function parseFile($fp) {
    // Check existence first so filesize() does not emit warnings for a
    // missing path before we get to report it.
    if (!is_file($fp) || !is_readable($fp)) {
        throw new Conjecture_ConjectureException("model file not found: $fp");
    }
    $size = filesize($fp);
    if ($size > $this->config->getMaxFileSize()) {
        throw new Conjecture_ConjectureException("model too big: " . $fp . " is " . $size . " bytes");
    }
    $res = file($fp);
    // file() yields false on read errors and an empty array for an empty file.
    if ($res) {
        return stripslashes(implode("", $res));
    }
    throw new Conjecture_ConjectureException("model file not found: $fp");
}
/**
 * Returns the raw JSON for a named model: the built-in dummy model when the
 * config asks for it, otherwise the file $file_name in the model directory.
 *
 * @param string $file_name model file name relative to the model path
 * @return string model JSON
 */
private function getLatestModelJsonForProblem($file_name) {
    if (!$this->config->useDummyConjectureModel()) {
        $path = $this->config->getConjectureModelPath() . "/" . $file_name;
        return $this->parseFile($path);
    }
    return self::getDummyModel();
}
/**
 * Loads and decodes the JSON model named $file_name (or the dummy model when
 * the config requests one).
 *
 * @param string $file_name model file name relative to the model path
 * @return mixed decoded JSON (objects become stdClass)
 * @throws Conjecture_ConjectureException if the model cannot be read or is not valid JSON
 */
public function getLatestModelForProblem($file_name) {
    $json = $this->getLatestModelJsonForProblem($file_name);
    $model = json_decode($json);
    // json_decode reports malformed input by returning null; raise the
    // library's own exception instead of handing callers a null model.
    if ($model === null && json_last_error() !== JSON_ERROR_NONE) {
        throw new Conjecture_ConjectureException("model json could not be decoded for " . $file_name . " (json error " . json_last_error() . ")");
    }
    return $model;
}
/**
 * Returns the weight vector (param->vector) of the named binary model.
 *
 * @param string $file_name model file name
 * @return Conjecture_Vector
 */
public function getLatestBinaryClassificationVectorForProblem($file_name) {
    $weights = $this->getLatestModelForProblem($file_name)->param->vector;
    return new Conjecture_Vector($weights);
}
/**
 * Builds a binary classifier from the named model's weight vector.
 *
 * @param string $file_name model file name
 * @return Conjecture_BinaryClassifier
 */
public function getLatestBinaryClassifierForProblem($file_name) {
    $vector = $this->getLatestBinaryClassificationVectorForProblem($file_name);
    return new Conjecture_BinaryClassifier($vector);
}
/**
 * Builds a one-vs-all classifier from a model whose top level maps each
 * category to that category's binary weight vector.
 *
 * @param string $file_name model file name
 * @return Conjecture_MulticlassOneVsAllClassifier
 */
public function getOneVsAllClassifier($file_name) {
    $model_array = $this->getLatestModelForProblem($file_name);
    // Initialise explicitly: an empty model previously passed an undefined
    // variable (null plus a notice) to the constructor.
    $category_params = [];
    foreach ($model_array as $cat => $params) {
        $category_params[$cat] = new Conjecture_BinaryClassifier(new Conjecture_Vector($params));
    }
    return new Conjecture_MulticlassOneVsAllClassifier($category_params);
}
/**
 * Builds a multiclass classifier from a model with per-category weight
 * vectors under `param`. Models whose `modelType` is
 * "multiclass_logistic_regression" get the softmax-style classifier; any
 * other type gets the generic one.
 *
 * @param string $file_name model file name
 * @return Conjecture_MulticlassClassifier|Conjecture_MulticlassLogisticRegressionClassifier
 */
public function getMulticlassClassifier($file_name) {
    $model = $this->getLatestModelForProblem($file_name);
    $type = $model->modelType;
    $vectors = [];
    foreach ($model->param as $label => $label_model) {
        $vectors[$label] = new Conjecture_Vector($label_model->vector);
    }
    // Loose comparison, matching the semantics of the original switch.
    if ($type == "multiclass_logistic_regression") {
        return new Conjecture_MulticlassLogisticRegressionClassifier($vectors);
    }
    return new Conjecture_MulticlassClassifier($vectors);
}
/**
 * Static factory: creates a finder bound to the given configuration.
 *
 * @param Conjecture_Config $config
 * @return Conjecture_Finder
 */
static function build(Conjecture_Config $config) {
    $finder = new Conjecture_Finder($config);
    return $finder;
}
/**
 * Returns the JSON of a placeholder model with an empty weight vector, for
 * development settings where real JSON models may not be present.
 *
 * @return string JSON-encoded dummy model
 */
private static function getDummyModel() {
    $params = [
        "vector" => [],
        "modelType" => "dummy",
        "regularizationWeights" => [],
        "epoch" => 1,
        "period" => 1,
        "truncationUpdate" => 0,
        "truncationThreshold" => 0,
        "initialLearningRate" => 0.1,
        "useExponentialLearningRate" => false,
        "exponentialLearningRate" => 1.0,
        "examplesPerEpoch" => 1,
    ];
    return json_encode(["param" => $params]);
}
}
================================================
FILE: clients/phplib/Conjecture/Instance.php
================================================
id;
}
/**
 * Sets this instance's identifier.
 *
 * @param mixed $id identifier to store
 * @return $this for chaining
 */
public function setId($id) {
    $this->id = $id;
    $self = $this;
    return $self;
}
/**
 * Sets a feature to an exact value, overwriting any previous value (compare
 * update(), which accumulates).
 *
 * @param string $key feature name
 * @param float $value feature value, 1.0 by default
 * @return $this for chaining, consistent with update() and addTerm()
 */
public function put($key, $value = 1.0) {
    $this->vector[$key] = $value;
    return $this;
}
/**
 * Adds $value to a feature, creating the feature when it is not present yet.
 *
 * @param string $key feature name
 * @param float $value amount to add, 1.0 by default
 * @return $this for chaining
 */
public function update($key, $value = 1.0) {
    if (!array_key_exists($key, $this->vector)) {
        $this->vector[$key] = $value;
        return $this;
    }
    $this->vector[$key] += $value;
    return $this;
}
// The methods below mirror the java.util.Map API this class stands in for.
/**
 * Copies every key/value pair of $vector into this instance via put(),
 * overwriting existing values.
 *
 * @param array $vector feature name => value
 * @return $this for chaining
 */
public function putAll(array $vector) {
    foreach ($vector as $key => $value) {
        $this->put($key, $value);
    }
    return $this;
}
/**
 * Tells whether a feature with this name is present (even when its value is null).
 *
 * @param string $key feature name
 * @return bool
 */
public function containsKey($key) {
    $present = array_key_exists($key, $this->vector);
    return $present;
}
/**
 * Tells whether any feature has the given value, using PHP's loose (==)
 * comparison.
 *
 * @param mixed $key the value to look for (despite the parameter name)
 * @return bool
 */
public function containsValue($key) {
    $strict = false;
    return in_array($key, $this->vector, $strict);
}
/**
 * Lists the feature names in insertion order.
 *
 * @return array
 */
public function keySet() {
    $names = array_keys($this->vector);
    return $names;
}
/**
 * Lists the feature values, re-indexed from 0, in insertion order.
 *
 * @return array
 */
public function values() {
    $weights = array_values($this->vector);
    return $weights;
}
/**
 * Returns the number of features stored.
 *
 * @return int
 */
public function size() {
    $entries = count($this->vector);
    return $entries;
}
/**
 * Tells whether no features are stored (empty() semantics).
 *
 * @return bool
 */
public function isEmpty() {
    $none = empty($this->vector);
    return $none;
}
/**
 * Deletes a feature; a missing feature name is ignored.
 *
 * @param string $key feature name
 */
public function remove($key) {
    unset($this->vector[$key]);
}
/**
 * Serialises the feature map to JSON.
 *
 * @return string
 */
public function toString() {
    $json = json_encode($this->vector);
    return $json;
}
/**
 * Adds $featureWeight to a term's feature. When a namespace is given, the
 * feature name is the namespace, then self::$NAMESPACE_SEP, then the term.
 *
 * @param string $term term to add
 * @param float $featureWeight amount to add, 1.0 by default
 * @param string $namespace optional namespace prefix
 * @return $this for chaining
 */
public function addTerm($term, $featureWeight = 1.0, $namespace = "") {
    if ($namespace == "") {
        $key = $term;
    } else {
        $key = $namespace . self::$NAMESPACE_SEP . $term;
    }
    $this->update($key, $featureWeight);
    return $this;
}
/**
 * Calls addTerm() for each term with the same weight and namespace.
 *
 * @param array $terms terms to add
 * @param float $featureWeight amount to add per term
 * @param string $namespace optional namespace prefix
 * @return $this for chaining
 */
public function addTerms(array $terms, $featureWeight = 1.0, $namespace = "") {
    foreach ($terms as $single) {
        $this->addTerm($single, $featureWeight, $namespace);
    }
    return $this;
}
/**
 * Adds each number as a feature named by its position ("0", "1", ...),
 * optionally namespaced.
 *
 * @param array $numberValues numeric values, taken in iteration order
 * @param string $namespace optional namespace prefix
 * @return $this for chaining
 */
public function addNumericArray(array $numberValues, $namespace = "") {
    // array_values() gives positional indices, so arrays with gaps or string
    // keys no longer hit undefined-index reads (null values); plain lists
    // behave exactly as before.
    foreach (array_values($numberValues) as $i => $value) {
        $this->addTerm((string)$i, $value, $namespace);
    }
    return $this;
}
}
================================================
FILE: clients/phplib/Conjecture/MulticlassClassifier.php
================================================
param = $param;
}
/**
 * Scores an instance against every category (dot product with each category's
 * weight vector) and divides each score by the sum of all scores.
 *
 * NOTE(review): raw dot products can be negative, so the result is not
 * guaranteed to be a probability distribution.
 *
 * @param mixed $instance_vec instance features
 * @return array category => normalised score; the raw scores are returned
 *               unchanged when they sum to zero
 */
public function predict($instance_vec) {
    $category_results = [];
    $total = 0;
    foreach ($this->param as $category => $classifier) {
        $prediction = $classifier->dot($instance_vec);
        $category_results[$category] = $prediction;
        $total += $prediction;
    }
    // Avoid dividing by zero (e.g. an instance sharing no features with any
    // category): PHP 8 throws DivisionByZeroError, and older versions yield NAN.
    if ($total == 0) {
        return $category_results;
    }
    return array_map(function($prob) use ($total) {
        return $prob / $total;
    }, $category_results);
}
/**
 * Returns the per-category parameter vectors this classifier was built with.
 *
 * @return array category => weight vector
 */
public function getParams() {
    $vectors = $this->param;
    return $vectors;
}
/**
 * Concatenates each category's top-feature explanation, in category order.
 *
 * @param mixed $instance_vec instance features
 * @param int $n maximum features per category
 * @return string comma-separated explanations (category names not included)
 */
public function explain($instance_vec, $n = 10) {
    $parts = [];
    foreach ($this->param as $label => $model) {
        $parts[$label] = $this->categoryExplain($instance_vec, $model, $n);
    }
    return implode(", ", $parts);
}
/**
 * Lists the $n features shared by the instance and one category's weight
 * vector with the largest absolute weights, as "feature(weight)".
 *
 * @param mixed $instance_vec instance features; must expose getParams()
 * @param Conjecture_Vector $category_model the category's weight vector
 * @param int $n maximum number of features
 * @return string comma-separated "feature(weight)" entries
 */
private function categoryExplain($instance_vec, $category_model, $n = 10) {
    $keys = array_intersect_key($category_model->getParams(), $instance_vec->getParams());
    $keys = array_map('abs', $keys);
    arsort($keys);
    // Keep the original keys: numeric feature names are integer keys that
    // array_slice would otherwise renumber.
    $res = array_slice($keys, 0, $n, true);
    foreach ($res as $k => $v) {
        // getParam($k) fetches one weight; the old getParams($k) returned the
        // whole array, which round() cannot handle.
        $res[$k] = "$k(" . round($category_model->getParam($k), 2) . ")";
    }
    return implode(", ", $res);
}
}
================================================
FILE: clients/phplib/Conjecture/MulticlassLogisticRegressionClassifier.php
================================================
param as $category => $classifier) {
$prediction = exp($classifier->dot($instance_vec));
$category_results[$category] = $prediction;
$total += $prediction;
}
return array_map( function($prob) use ($total) {
return $prob / $total;
}, $category_results);
}
}
================================================
FILE: clients/phplib/Conjecture/MulticlassOneVsAllClassifier.php
================================================
param = $param;
}
/**
 * Runs every per-category binary classifier and divides each predicted
 * probability by their sum, so the results add up to 1.
 *
 * @param mixed $instance_vec instance features
 * @return array category => normalised probability; the raw probabilities
 *               are returned unchanged when they sum to zero
 */
public function predict($instance_vec) {
    $category_results = [];
    $total = 0;
    foreach ($this->param as $category => $classifier) {
        $prediction = $classifier->predict($instance_vec);
        $category_results[$category] = $prediction;
        $total += $prediction;
    }
    // Every sigmoid can underflow to 0 for very negative scores; avoid
    // dividing by zero in that case.
    if ($total == 0) {
        return $category_results;
    }
    return array_map(function($prob) use ($total) {
        return $prob / $total;
    }, $category_results);
}
/**
 * Collects each category classifier's parameters.
 *
 * @return array category => parameters
 */
public function getParams() {
    $collected = [];
    foreach ($this->param as $label => $binary) {
        $collected[$label] = $binary->getParams();
    }
    return $collected;
}
/**
 * Concatenates each category classifier's explanation, in category order.
 *
 * @param mixed $instance_vec instance features
 * @param int $n maximum features per category
 * @return string comma-separated explanations
 */
public function explain($instance_vec, $n = 10) {
    $parts = [];
    foreach ($this->param as $label => $binary) {
        $parts[$label] = $binary->explain($instance_vec, $n);
    }
    return implode(", ", $parts);
}
}
================================================
FILE: clients/phplib/Conjecture/Text.php
================================================
input = $text;
}
/**
 * Returns the wrapped text unchanged.
 *
 * @return string
 */
function toString() {
    $raw = $this->input;
    return $raw;
}
/**
 * Replaces every run of digits with $replacement, then merges each
 * non-overlapping pair of replacement tokens separated only by whitespace
 * into one token.
 *
 * NOTE(review): a run of three numbers only partly collapses ("1 2 3" gives
 * "_num_ _num_"); left as-is in case features must match the training side.
 *
 * @param string $replacement token standing in for numbers
 * @return Conjecture_Text
 */
function replaceNumbers($replacement = "_num_") {
    $text = preg_replace("/[0-9]+/", $replacement, $this->input);
    // $replacement is literal text: quote it so metacharacters (".", "+",
    // "/", ...) match literally and cannot break the pattern. This is a no-op
    // for the default "_num_".
    $quoted = preg_quote($replacement, "/");
    return new Conjecture_Text(preg_replace("/" . $quoted . "\\s+" . $quoted . "/", $replacement, $text));
}
/**
 * Replaces HTML character references ("&", one or more non-";" characters,
 * then ";", e.g. "&amp;") with $replacement.
 *
 * @param string $replacement
 * @return Conjecture_Text
 */
function replaceHTMLEscapes($replacement = " ") {
    $entityPattern = "/&[^;]+;/";
    $cleaned = preg_replace($entityPattern, $replacement, $this->input);
    return new Conjecture_Text($cleaned);
}
/**
 * Replaces each "<...>" span (shortest match) with a single space.
 *
 * @return Conjecture_Text
 */
function removeHTMLTags() {
    $tagPattern = "/<.*?>/";
    $cleaned = preg_replace($tagPattern, " ", $this->input);
    return new Conjecture_Text($cleaned);
}
/**
 * Replaces each "<...>" tag with $replacement.
 *
 * @param string $replacement text to put in place of each tag
 * @return Conjecture_Text
 */
function replaceHTMLTags($replacement = " ") {
    // Use the caller's replacement. The argument used to be ignored and a
    // space always inserted; the default keeps that behaviour.
    return new Conjecture_Text(preg_replace("/<[^>]+>/", $replacement, $this->input));
}
/**
 * Replaces each run of characters other than ASCII letters, digits, ".",
 * whitespace and "-" with $replacement.
 *
 * @param string $replacement
 * @return Conjecture_Text
 */
function replaceNonAlphaNumeric($replacement = " ") {
    $pattern = "/[^a-zA-Z0-9\\.\\s\\-]+/";
    $cleaned = preg_replace($pattern, $replacement, $this->input);
    return new Conjecture_Text($cleaned);
}
/**
 * Like replaceNonAlphaNumeric(), but also keeps "_".
 *
 * @param string $replacement
 * @return Conjecture_Text
 */
function replaceNonAlphaNumericUnderscore($replacement = " ") {
    $pattern = "/[^a-zA-Z0-9\\.\\s\\-_]+/";
    $cleaned = preg_replace($pattern, $replacement, $this->input);
    return new Conjecture_Text($cleaned);
}
/**
 * Replaces each run of non-ASCII-letter characters with $replacement.
 *
 * @param string $replacement
 * @return Conjecture_Text
 */
function replaceNonAlpha($replacement = " ") {
    $pattern = "/[^a-zA-Z]+/";
    $cleaned = preg_replace($pattern, $replacement, $this->input);
    return new Conjecture_Text($cleaned);
}
/**
 * Shortens every run of two or more hyphens to exactly "--".
 *
 * @return Conjecture_Text
 */
function collapseHyphens() {
    $collapsed = preg_replace("/--+/", "--", $this->input);
    return new Conjecture_Text($collapsed);
}
/**
 * Shortens every run of two or more underscores to exactly "__".
 *
 * @return Conjecture_Text
 */
function collapseUnderscores() {
    $collapsed = preg_replace("/__+/", "__", $this->input);
    return new Conjecture_Text($collapsed);
}
/**
 * Shortens every run of two or more periods to exactly "..".
 *
 * @return Conjecture_Text
 */
function collapsePeriods() {
    $collapsed = preg_replace("/\.\.+/", "..", $this->input);
    return new Conjecture_Text($collapsed);
}
/**
 * Removes one non-alphanumeric character from the start of the text and one
 * from the end, when present.
 *
 * @return Conjecture_Text
 */
function stripPunctuation() {
    // The patterns previously had no delimiters, so PCRE rejected them with
    // "Unknown modifier" warnings and preg_replace returned null.
    $temp = preg_replace("/^[^A-Za-z0-9]/", "", $this->input);
    return new Conjecture_Text(preg_replace("/[^A-Za-z0-9]$/", "", $temp));
}
// Replace every whitespace run with a single space.
function collapse() {
    $compacted = preg_replace("/\\s+/", " ", $this->input);
    return new Conjecture_Text($compacted);
}
// Drop trailing whitespace.
function rstrip() {
    $trimmed = preg_replace("/\\s+$/", "", $this->input);
    return new Conjecture_Text($trimmed);
}
// Drop leading whitespace.
function lstrip() {
    $trimmed = preg_replace("/^\\s+/", "", $this->input);
    return new Conjecture_Text($trimmed);
}
// Drop both trailing and leading whitespace (trailing first).
function strip() {
    $rightTrimmed = $this->rstrip();
    return $rightTrimmed->lstrip();
}
// Trim both ends, then collapse inner whitespace runs to single spaces.
function wsclean() {
    $trimmed = $this->strip();
    return $trimmed->collapse();
}
// Delete every byte outside printable ASCII (0x20-0x7E): tabs, newlines and
// each byte of multi-byte UTF-8 characters included.
function removeUnprintables() {
    $printable = preg_replace("/[^\\x20-\\x7E]/", "", $this->input);
    return new Conjecture_Text($printable);
}
/**
 * Collapses whitespace runs to one space, hyphen runs to "--", and runs of
 * two or more periods to a single ".".
 *
 * @return Conjecture_Text
 */
function collapseWhitespaceAndPunc() {
    $normalised = $this->collapse()->collapseHyphens()->toString();
    $singlePeriods = preg_replace("/\\.\\.+/", ".", $normalised);
    return new Conjecture_Text($singlePeriods);
}
/**
 * Lower-cases the text with strtolower().
 *
 * @return Conjecture_Text
 */
function toLowerCase() {
    $lowered = strtolower($this->input);
    return new Conjecture_Text($lowered);
}
/**
 * Default cleaning pipeline, applied in this order: strip tags, replace HTML
 * escapes, replace numbers, replace other punctuation, collapse hyphen and
 * underscore runs, then trim and collapse whitespace.
 *
 * @return Conjecture_Text
 */
function standardTextFilter() {
    $step = $this->removeHTMLTags();
    $step = $step->replaceHTMLEscapes();
    $step = $step->replaceNumbers();
    $step = $step->replaceNonAlphaNumericUnderscore();
    $step = $step->collapseHyphens();
    $step = $step->collapseUnderscores();
    return $step->wsclean();
}
/**
 * Returns every contiguous $n-character (byte) substring, in order.
 *
 * @param int $n shingle length
 * @return array list of shingles
 */
function toArrayFromShingles($n) {
    $shingles = array();
    $chars = str_split($this->input);
    $lastStart = count($chars) - $n + 1;
    for ($start = 0; $start < $lastStart; $start++) {
        $shingles[] = implode("", array_slice($chars, $start, $n));
    }
    return $shingles;
}
/**
 * Wraps the $n-character shingles of the text in a Conjecture_TextSequence.
 *
 * @param int $n shingle length
 * @return Conjecture_TextSequence
 */
function toSequenceFromShingles($n) {
    $pieces = $this->toArrayFromShingles($n);
    return new Conjecture_TextSequence($pieces);
}
}
================================================
FILE: clients/phplib/Conjecture/TextSequence.php
================================================
tokens = $tokens;
}
/**
 * Returns a new sequence containing this sequence's tokens followed by
 * $other's (array_merge semantics).
 *
 * @param Conjecture_TextSequence $other
 * @return Conjecture_TextSequence
 */
function concat($other) {
    $joined = array_merge($this->tokens, $other->tokens);
    return new Conjecture_TextSequence($joined);
}
/**
 * Joins the tokens with $glue.
 *
 * @param string $glue separator, a single space by default
 * @return string
 */
function mkString($glue = " ") {
    $joined = implode($glue, $this->tokens);
    return $joined;
}
/**
 * Returns the tokens joined by single spaces.
 *
 * @return string
 */
function toString() {
    $separator = " ";
    return $this->mkString($separator);
}
/**
 * Returns the underlying token array (keys preserved from any filtering).
 *
 * @return array
 */
function getTokens() {
    $all = $this->tokens;
    return $all;
}
/**
 * Drops tokens that are exactly the empty string.
 *
 * @return Conjecture_TextSequence
 */
function filterBlank() {
    $kept = array_filter($this->tokens, function ($token) {
        return $token !== "";
    });
    return new Conjecture_TextSequence($kept);
}
/**
 * Drops tokens found in the stopword list (loose in_array comparison).
 *
 * @return Conjecture_TextSequence
 */
function filterStopwords() {
    $stoplist = self::$stopwordList;
    $kept = array_filter($this->tokens, function ($token) use ($stoplist) {
        return !in_array($token, $stoplist);
    });
    return new Conjecture_TextSequence($kept);
}
/**
 * Keeps only tokens found in the stopword list.
 *
 * @return Conjecture_TextSequence
 */
function stopwords() {
    $stoplist = self::$stopwordList;
    $kept = array_filter($this->tokens, function ($token) use ($stoplist) {
        return in_array($token, $stoplist);
    });
    return new Conjecture_TextSequence($kept);
}
/**
 * Drops tokens found in the bad-word list.
 *
 * @return Conjecture_TextSequence
 */
function filterBadwords() {
    $blocklist = self::$badwordList;
    $kept = array_filter($this->tokens, function ($token) use ($blocklist) {
        return !in_array($token, $blocklist);
    });
    return new Conjecture_TextSequence($kept);
}
/**
 * Keeps only tokens found in the bad-word list.
 *
 * @return Conjecture_TextSequence
 */
function badwords() {
    $blocklist = self::$badwordList;
    $kept = array_filter($this->tokens, function ($token) use ($blocklist) {
        return in_array($token, $blocklist);
    });
    return new Conjecture_TextSequence($kept);
}
/**
 * Drops tokens made only of uppercase ASCII letters.
 *
 * @return Conjecture_TextSequence
 */
function filterAllCaps() {
    $kept = array_filter($this->tokens, function ($token) {
        return !preg_match('/^[A-Z]+$/', $token);
    });
    return new Conjecture_TextSequence($kept);
}
/**
 * Keeps only tokens made of uppercase ASCII letters.
 *
 * @return Conjecture_TextSequence
 */
function AllCaps() {
    $kept = array_filter($this->tokens, function ($token) {
        return preg_match('/^[A-Z]+$/', $token);
    });
    return new Conjecture_TextSequence($kept);
}
/**
 * Drops tokens that are one uppercase letter followed only by non-uppercase
 * characters.
 *
 * @return Conjecture_TextSequence
 */
function filterCapitalized() {
    $kept = array_filter($this->tokens, function ($token) {
        return !preg_match('/^[A-Z][^A-Z]+$/', $token);
    });
    return new Conjecture_TextSequence($kept);
}
function capitalized() {
return new Conjecture_TextSequence(array_filter($this->tokens,
function($x) {
return preg_match('/^[A-Z][^A-Z]+$/', $x);
}
)
);
}
function filterLowercase() {
return new Conjecture_TextSequence(array_filter($this->tokens,
function($x) {
return !preg_match('/^[a-z]+$/', $x);
}
)
);
}
function allLowercase() {
return new Conjecture_TextSequence(array_filter($this->tokens,
function($x) {
return preg_match('/^[a-z]+$/', $x);
}
)
);
}
/**
 * Drops tokens that start with "http://" or "https://".
 *
 * '#' delimits the pattern because the pattern itself contains '/'. With
 * '/' as the delimiter, PCRE ends the pattern at "https?:" and rejects the
 * rest as unknown modifiers, so preg_match fails for every token.
 *
 * @return Conjecture_TextSequence
 */
function filterURLs() {
    return new Conjecture_TextSequence(array_filter($this->tokens,
        function($x) {
            return !preg_match('#^https?://.+#', $x);
        }
    ));
}
/**
 * Keeps only tokens that start with "http://" or "https://".
 *
 * @return Conjecture_TextSequence
 */
function allURLs() {
    return new Conjecture_TextSequence(array_filter($this->tokens,
        function($x) {
            return preg_match('#^https?://.+#', $x);
        }
    ));
}
/**
 * Drops http(s) URL tokens containing "etsy" followed later by
 * "/listing/<digits>".
 *
 * @return Conjecture_TextSequence
 */
function filterListings() {
    return new Conjecture_TextSequence(array_filter($this->tokens,
        function($x) {
            return !preg_match('#^https?://.+etsy.+/listing/[0-9]+.*#', $x);
        }
    ));
}
/**
 * Keeps only http(s) URL tokens containing "etsy" followed later by
 * "/listing/<digits>".
 *
 * @return Conjecture_TextSequence
 */
function allListings() {
    return new Conjecture_TextSequence(array_filter($this->tokens,
        function($x) {
            return preg_match('#^https?://.+etsy.+/listing/[0-9]+.*#', $x);
        }
    ));
}
/**
 * Number of tokens in this sequence.
 *
 * @return int
 */
function size() {
    return count($this->tokens);
}
/**
 * Number of tokens found in the stopword list.
 *
 * @return int
 */
function stopWordCount() {
    return $this->stopwords()->size();
}
/**
 * Binned stopword fraction; see binnedFraction().
 *
 * @param float $bins number of bins, 10 by default
 * @return float
 */
function stopWordFraq($bins = 10.0) {
    return $this->binnedFraction($this->stopWordCount(), $bins);
}
/**
 * Number of tokens found in the badword list.
 *
 * @return int
 */
function badWordCount() {
    return $this->badwords()->size();
}
/**
 * Binned badword fraction; see binnedFraction().
 *
 * @param float $bins number of bins, 10 by default
 * @return float
 */
function badWordFraq($bins = 10.0) {
    return $this->binnedFraction($this->badWordCount(), $bins);
}
/**
 * Number of all-capitals tokens (method names are case-insensitive in PHP,
 * so allCaps() resolves to AllCaps()).
 *
 * @return int
 */
function capsCount() {
    return $this->allCaps()->size();
}
/**
 * Binned all-capitals fraction; see binnedFraction().
 *
 * @param float $bins number of bins, 10 by default
 * @return float
 */
function capFraq($bins = 10.0) {
    return $this->binnedFraction($this->capsCount(), $bins);
}
/**
 * Number of http(s) URL tokens.
 *
 * @return int
 */
function urlCount() {
    return $this->allURLs()->size();
}
/**
 * Binned URL fraction; see binnedFraction().
 *
 * @param float $bins number of bins, 10 by default
 * @return float
 */
function urlFraq($bins = 10.0) {
    return $this->binnedFraction($this->urlCount(), $bins);
}
/**
 * Number of tokens matched by allListings().
 *
 * @return int
 */
function listingsCount() {
    return $this->allListings()->size();
}
/**
 * Binned listing-URL fraction; see binnedFraction(). Uses the listing
 * count, not the sequence object returned by allListings().
 *
 * @param float $bins number of bins, 10 by default
 * @return float
 */
function listingsFraq($bins = 10.0) {
    return $this->binnedFraction($this->listingsCount(), $bins);
}
/**
 * floor(ln(size())).
 * NOTE(review): an empty sequence yields -INF (log(0)); confirm callers
 * tolerate that.
 *
 * @return float
 */
function sizeBin() {
    return floor(log($this->size()));
}
/**
 * Shared computation behind the *Fraq() methods:
 * floor(round($bins * $count / size()) / $bins).
 * Returns 0.0 for an empty sequence instead of dividing by zero.
 * NOTE(review): the outer floor() collapses results to 0 or 1 whenever
 * $count <= size(); if $bins-level quantization was intended this should
 * probably be round(...) / $bins. The formula is kept unchanged so that
 * existing feature values (and models trained on them) do not shift.
 *
 * @param int|float $count numerator
 * @param float $bins number of bins
 * @return float
 */
private function binnedFraction($count, $bins) {
    $total = $this->size();
    if ($total == 0) {
        return 0.0;
    }
    return floor(round($bins * $count / $total) / $bins);
}
// token normalization methods (TODO: more filters)
/**
 * Replaces every run of digits in each token with $replacement, then merges
 * runs of placeholders separated only by whitespace into a single
 * placeholder (e.g. "call 555 1234 now" -> "call _num_ now").
 *
 * $replacement is regex-quoted before it is embedded in the merge pattern,
 * so placeholders containing metacharacters ('.', '*', '/', ...) are safe.
 * The merge collapses runs of any length, not just pairs.
 *
 * @param string $replacement placeholder text, "_num_" by default
 * @return Conjecture_TextSequence
 */
function replaceNumbers($replacement = "_num_") {
    $quoted = preg_quote($replacement, "/");
    $placeholderRun = "/" . $quoted . "(?:\\s+" . $quoted . ")+/";
    return new Conjecture_TextSequence(array_map(
        function($x) use ($replacement, $placeholderRun) {
            $text = preg_replace("/[0-9]+/", $replacement, $x);
            return preg_replace($placeholderRun, $replacement, $text);
        }, $this->tokens));
}
/**
 * Replaces, in each token, every span from '&' up to the next ';' (HTML
 * character references such as "&amp;") with $replacement.
 *
 * @param string $replacement a single space by default
 * @return Conjecture_TextSequence
 */
function replaceHTMLEscapes($replacement = " ") {
    $pattern = "/&[^;]+;/";
    $rewrite = function($token) use ($pattern, $replacement) {
        return preg_replace($pattern, $replacement, $token);
    };
    return new Conjecture_TextSequence(array_map($rewrite, $this->tokens));
}
/**
 * Replaces HTML tags with a single space; shorthand for replaceHTMLTags(" ").
 *
 * @return Conjecture_TextSequence
 */
function removeHTMLTags() {
    $space = " ";
    return $this->replaceHTMLTags($space);
}
/**
 * Replaces every "<...>" span in each token with $replacement.
 *
 * @param string $replacement a single space by default
 * @return Conjecture_TextSequence
 */
function replaceHTMLTags($replacement = " ") {
    $pattern = "/<[^>]+>/";
    $rewrite = function($token) use ($pattern, $replacement) {
        return preg_replace($pattern, $replacement, $token);
    };
    return new Conjecture_TextSequence(array_map($rewrite, $this->tokens));
}
/**
 * Replaces each run of characters other than ASCII letters, digits, '.',
 * whitespace and '-' with $replacement.
 *
 * @param string $replacement a single space by default
 * @return Conjecture_TextSequence
 */
function replaceNonAlphaNumeric($replacement = " ") {
    $pattern = "/[^a-zA-Z0-9\\.\\s\\-]+/";
    $rewrite = function($token) use ($pattern, $replacement) {
        return preg_replace($pattern, $replacement, $token);
    };
    return new Conjecture_TextSequence(array_map($rewrite, $this->tokens));
}
/**
 * Like replaceNonAlphaNumeric(), but '_' is also kept.
 *
 * @param string $replacement a single space by default
 * @return Conjecture_TextSequence
 */
function replaceNonAlphaNumericUnderscore($replacement = " ") {
    $pattern = "/[^a-zA-Z0-9\\.\\s\\-_]+/";
    $rewrite = function($token) use ($pattern, $replacement) {
        return preg_replace($pattern, $replacement, $token);
    };
    return new Conjecture_TextSequence(array_map($rewrite, $this->tokens));
}
/**
 * Replaces each run of characters other than ASCII letters, '.',
 * whitespace, '-' and '_' (so digits are replaced too) with $replacement.
 *
 * @param string $replacement a single space by default
 * @return Conjecture_TextSequence
 */
function replaceNonAlpha($replacement = " ") {
    $pattern = "/[^a-zA-Z\\.\\s\\-_]+/";
    $rewrite = function($token) use ($pattern, $replacement) {
        return preg_replace($pattern, $replacement, $token);
    };
    return new Conjecture_TextSequence(array_map($rewrite, $this->tokens));
}
/**
 * Shortens every run of two or more hyphens to exactly "--".
 *
 * @return Conjecture_TextSequence
 */
function collapseHyphens() {
    $rewrite = function($token) {
        return preg_replace("/--+/", "--", $token);
    };
    return new Conjecture_TextSequence(array_map($rewrite, $this->tokens));
}
/**
 * Shortens every run of two or more underscores to exactly "__".
 *
 * @return Conjecture_TextSequence
 */
function collapseUnderscores() {
    $rewrite = function($token) {
        return preg_replace("/__+/", "__", $token);
    };
    return new Conjecture_TextSequence(array_map($rewrite, $this->tokens));
}
/**
 * Shortens every run of two or more periods to exactly "..".
 *
 * @return Conjecture_TextSequence
 */
function collapsePeriods() {
    $rewrite = function($token) {
        return preg_replace("/\.\.+/", "..", $token);
    };
    return new Conjecture_TextSequence(array_map($rewrite, $this->tokens));
}
/**
 * Removes leading and trailing characters that are not ASCII letters or
 * digits from every token ("...hello!!" -> "hello"); interior punctuation
 * is left alone.
 *
 * The patterns need explicit '/' delimiters. Without them PCRE takes the
 * first character ('^' or '[') as the delimiter, rejects the pattern, and
 * preg_replace returns null for every token.
 *
 * @return Conjecture_TextSequence
 */
function stripPunctuation() {
    return new Conjecture_TextSequence(array_map(
        function($x) {
            $temp = preg_replace('/^[^A-Za-z0-9]+/', "", $x);
            return preg_replace('/[^A-Za-z0-9]+$/', "", $temp);
        }, $this->tokens));
}
/**
 * Compacts whitespace: every run of whitespace inside a token becomes a
 * single space.
 *
 * @return Conjecture_TextSequence
 */
function collapse() {
    $pattern = "/\\s+/";
    return new Conjecture_TextSequence(array_map(function($token) use ($pattern) {
        return preg_replace($pattern, " ", $token);
    }, $this->tokens));
}
/**
 * Removes trailing (right-hand) whitespace from every token.
 * The bodies of rstrip and lstrip used to be swapped; strip() applies both,
 * so its result is unaffected by the correction.
 *
 * @return Conjecture_TextSequence
 */
function rstrip() {
    return new Conjecture_TextSequence(array_map(
        function($x) {
            return preg_replace("/\\s+$/", "", $x);
        }, $this->tokens));
}
/**
 * Removes leading (left-hand) whitespace from every token.
 *
 * @return Conjecture_TextSequence
 */
function lstrip() {
    return new Conjecture_TextSequence(array_map(
        function($x) {
            return preg_replace("/^\\s+/", "", $x);
        }, $this->tokens));
}
/**
 * Removes leading and trailing whitespace from every token by applying
 * rstrip() and then lstrip().
 *
 * @return Conjecture_TextSequence
 */
function strip() {
    $once = $this->rstrip();
    return $once->lstrip();
}
/**
 * Cleans up whitespace: strip() every token, then collapse() the internal
 * whitespace runs to single spaces.
 *
 * @return Conjecture_TextSequence
 */
function wsclean() {
    $stripped = $this->strip();
    return $stripped->collapse();
}
/**
 * Deletes every character outside printable ASCII (0x20-0x7E) from each
 * token. The pattern has no /u modifier and works on bytes, so this also
 * drops tabs, newlines and every byte of multi-byte (e.g. UTF-8) characters.
 *
 * @return Conjecture_TextSequence
 */
function removeUnprintables() {
    $outsidePrintable = "/[^\\x20-\\x7E]/";
    return new Conjecture_TextSequence(array_map(function($token) use ($outsidePrintable) {
        return preg_replace($outsidePrintable, "", $token);
    }, $this->tokens));
}
/**
 * In each token, in order: whitespace runs -> " ", hyphen runs -> "-",
 * period runs -> ".".
 *
 * @return Conjecture_TextSequence
 */
function collapseWhitespaceAndPunc() {
    $rules = array(
        "/\\s+/" => " ",
        "/[\\-]+/" => "-",
        "/[\\.]+/" => ".",
    );
    return new Conjecture_TextSequence(array_map(function($token) use ($rules) {
        foreach ($rules as $pattern => $replacement) {
            $token = preg_replace($pattern, $replacement, $token);
        }
        return $token;
    }, $this->tokens));
}
/**
 * Prefixes every token with $namespace, keeping the token keys.
 *
 * @param string $namespace prefix to prepend
 * @return Conjecture_TextSequence
 */
function prependNameSpace($namespace) {
    $prefixed = array();
    foreach ($this->tokens as $key => $token) {
        $prefixed[$key] = $namespace . $token;
    }
    return new Conjecture_TextSequence($prefixed);
}
/**
 * Returns the token array (same as getTokens()).
 *
 * @return array
 */
function toList() {
    $list = $this->tokens;
    return $list;
}
/**
 * Character shingles of the whole sequence. The tokens are joined with
 * $whitespace, and every contiguous substring of length $n is returned
 * (byte-based, like str_split), in order.
 *
 * This mirrors the shingling loop in Text.php. The previous version
 * called explode('', ...) (an error: empty delimiter), sliced $i + $n
 * characters instead of $n, and skipped the final shingle.
 *
 * @param int $n shingle length; values below 1 yield an empty sequence
 * @param string $whitespace glue placed between tokens, "_" by default
 * @return Conjecture_TextSequence
 */
function shingles($n, $whitespace = "_") {
    $str = implode($whitespace, $this->tokens);
    $shingles = array();
    // str_split("") returns array("") before PHP 8.2; treat it as empty.
    if ($n < 1 || $str === "") {
        return new Conjecture_TextSequence($shingles);
    }
    $chars = str_split($str);
    $last = count($chars) - $n;
    for ($i = 0; $i <= $last; $i++) {
        $shingles[] = implode('', array_slice($chars, $i, $n));
    }
    return new Conjecture_TextSequence($shingles);
}
/**
 * Word n-grams: every window of $n consecutive tokens, joined with $glue.
 *
 * @param int $n window size
 * @param string $glue separator inside each n-gram, a space by default
 * @return Conjecture_TextSequence
 */
function ngrams($n, $glue = " ") {
    $grams = array();
    $limit = count($this->tokens) - $n + 1;
    for ($start = 0; $start < $limit; $start++) {
        $window = array_slice($this->tokens, $start, $n);
        $grams[] = implode($glue, $window);
    }
    return new Conjecture_TextSequence($grams);
}
/**
 * All unigrams followed by all bigrams (bigrams joined with $glue).
 *
 * @param string $glue bigram separator, a space by default
 * @return Conjecture_TextSequence
 */
function unigramsAndBigrams($glue = " ") {
    $unigrams = $this->ngrams(1);
    $bigrams = $this->ngrams(2, $glue);
    return $unigrams->concat($bigrams);
}
/**
 * Builds a Conjecture_Instance and calls addTerm() once per token, in order.
 *
 * @return Conjecture_Instance
 */
function toInstance() {
    $result = new Conjecture_Instance();
    foreach ($this->tokens as $term) {
        $result->addTerm($term);
    }
    return $result;
}
/**
 * Stopwords used by stopwords()/filterStopwords() (loose in_array match).
 * A few entries are duplicated ("as", "cant", "its", "thats", ...); this is
 * harmless for in_array. "hopefully" used to be split across a line break
 * inside its string literal, so it never matched.
 * NOTE(review): the run "up","re","werent" looks truncated; common
 * stopwords such as "upon", "us", "was", "we", "were" are absent. Confirm
 * against the upstream list.
 */
static $stopwordList = array("a","as","able","about","above","according","accordingly","across","actually","after","afterwards","again","against","aint","all","allow","allows","almost","alone","along","already","also","although","always","am","among","amongst","amoungst","amount","an","and","another","any","anybody","anyhow","anyone","anything","anyway","anyways","anywhere","apart","appear","appreciate","appropriate","are","arent","around","as","aside","ask","asking","associated","at","available","away","awfully","b","back","be","became","because","become","becomes","becoming","been","before","beforehand","behind","being","believe","below","beside","besides","best","better","between","beyond","bill","both","bottom","brief","but","by","c","cmon","cs","call","came","can","cant","cannot","cant","cause","causes","certain","certainly","changes","clearly","co","com","come","comes","con","concerning","consequently","consider","considering","contain","containing","contains","corresponding","could","couldnt","couldnt","course","cry","currently","d","de","definitely","describe","described","despite","detail","did","didnt","different","do","does","doesnt","doing","dont","done","down","downwards","due","during","e","each","edu","eg","eight","either","eleven","else","elsewhere","empty","enough","entirely","especially","et","etc","even","ever","every","everybody","everyone","everything","everywhere","ex","exactly","example","except","f","far","few","fifteen","fifth","fify","fill","find","fire","first","five","followed","following","follows","for","former","formerly","forth","forty","found","four","from","front","full","further","furthermore","g","get","gets","getting","give","given","gives","go","goes","going","gone","got","gotten","greetings","h","had","hadnt","happens","hardly","has","hasnt","hasnt","have","havent","having","he","hes","hello","help","hence","her","here","heres","hereafter","hereby","herein","hereupon","hers","herself","hi","him","himself","his","hither",
"hopefully","how","howbeit","however","hundred","i","id","ill","im","ive","ie","if","ignored","immediate","in","inasmuch","inc","indeed","indicate","indicated","indicates","inner","insofar","instead","interest","into","inward","is","isnt","it","itd","itll","its","its","itself","j","just","k","keep","keeps","kept","know","known","knows","l","last","lately","later","latter","latterly","least","less","lest","let","lets","like","liked","likely","little","look","looking","looks","ltd","m","made","mainly","many","may","maybe","me","mean","meanwhile","merely","might","mill","mine","more","moreover","most","mostly","move","much","must","my","myself","n","name","namely","nd","near","nearly","necessary","need","needs","neither","never","nevertheless","new","next","nine","no","nobody","non","none","noone","nor","normally","not","nothing","novel","now","nowhere","o","obviously","of","off","often","oh","ok","okay","old","on","once","one","ones","only","onto","or","other","others","otherwise","ought","our","ours","ourselves","out","outside","over","overall","own","p","part","particular","particularly","per","perhaps","placed","please","plus","possible","presumably","probably","provides","put","q","que","quite","qv","r","rather","rd","re","really","reasonably","regarding","regardless","regards","relatively","respectively","right","s","said","same","saw","say","saying","says","second","secondly","see","seeing","seem","seemed","seeming","seems","seen","self","selves","sensible","sent","serious","seriously","seven","several","shall","she","should","shouldnt","show","side","since","sincere","six","sixty","so","some","somebody","somehow","someone","something","sometime","sometimes","somewhat","somewhere","soon","sorry","specified","specify","specifying","still","sub","such","sup","sure","system","t","ts","take","taken","tell","ten","tends","th","than","thank","thanks","thanx","that","thats","thats","the","thea","their","theirs","them","themselves","then","thence","there","theres","thereafter"
,"thereby","therefore","therein","theres","thereupon","these","they","theyd","theyll","theyre","theyve","thickv","thin","think","third","this","thorough","thoroughly","those","though","three","through","throughout","thru","thus","to","together","too","took","top","toward","towards","tried","tries","truly","try","trying","twelve","twenty","twice","two","u","un","under","unfortunately","unless","unlikely","until","unto","up","re","werent","what","whats","whatever","when","whence","whenever","where","wheres","whereafter","whereas","whereby","wherein","whereupon","wherever","whether","which","while","whither","who","whos","whoever","whole","whom","whose","why","will","willing","wish","with","within","without","wont","wonder","would","wouldnt","x","y","yes","yet","you","youd","youll","youre","youve","your","yours","yourself","yourselves","z","zero");
/**
 * Profanity list used by badwords()/filterBadwords() (loose in_array match).
 */
static $badwordList = array("ahole", "arse", "ass", "asshole", "asswipe", "bastard", "batty", "bender", "bitch", "bloody", "bollocks", "boner", "bumboy", "bugger", "coon", "cock", "cocksucker", "cracker", "crap", "cumsucker", "cunt", "damn", "dick", "dildo", "douchebag", "faggot", "fistfucker", "fuck", "fucker", "fuckwit", "fucktwat", "gaylord", "ho", "honky", "jackass", "jism", "joey", "knobcheese", "minge", "minger", "mong", "motherfucker", "munter", "pickle", "piss", "piss", "prick", "pussy", "rimmer", "schmuck", "shit", "slut", "spakka", "spaz", "skank", "taint", "tit", "tool", "tosser", "twat", "whore", "wanker");
}
================================================
FILE: clients/phplib/Conjecture/Vector.php
================================================
vector = (array)$array;
}
/**
 * Sparse dot product: sums vector[k] * rhs->vector[k] over the keys present
 * in both vectors, iterating in this vector's key order.
 *
 * @param Conjecture_Vector $rhs
 * @return float
 */
public function dot($rhs) {
    $total = 0.0;
    $shared = array_intersect_key($this->vector, $rhs->vector);
    foreach (array_keys($shared) as $k) {
        $total += $this->vector[$k] * $rhs->vector[$k];
    }
    return $total;
}
/**
 * Returns the underlying key => weight array.
 *
 * @return array
 */
public function getParams() {
    $params = $this->vector;
    return $params;
}
/**
 * Weight stored under $k, or 0.0 when the key is absent. array_key_exists
 * is used so that a stored null is still returned as-is.
 *
 * @param mixed $k
 * @return mixed
 */
public function getParam($k) {
    return array_key_exists($k, $this->vector) ? $this->vector[$k] : 0.0;
}
}
================================================
FILE: data/iris.tsv
================================================
7.0 3.2 4.7 1.4 Iris-versicolor
5.6 3.0 4.1 1.3 Iris-versicolor
5.4 3.4 1.7 0.2 Iris-setosa
5.0 3.0 1.6 0.2 Iris-setosa
6.9 3.2 5.7 2.3 Iris-virginica
4.9 3.0 1.4 0.2 Iris-setosa
5.0 2.3 3.3 1.0 Iris-versicolor
5.2 2.7 3.9 1.4 Iris-versicolor
5.1 3.8 1.9 0.4 Iris-setosa
7.2 3.6 6.1 2.5 Iris-virginica
4.8 3.4 1.6 0.2 Iris-setosa
6.0 2.9 4.5 1.5 Iris-versicolor
5.8 2.6 4.0 1.2 Iris-versicolor
5.7 2.6 3.5 1.0 Iris-versicolor
5.9 3.0 4.2 1.5 Iris-versicolor
5.5 2.3 4.0 1.3 Iris-versicolor
4.6 3.2 1.4 0.2 Iris-setosa
6.3 2.8 5.1 1.5 Iris-virginica
6.3 3.3 6.0 2.5 Iris-virginica
6.9 3.1 4.9 1.5 Iris-versicolor
6.7 3.3 5.7 2.5 Iris-virginica
5.1 3.7 1.5 0.4 Iris-setosa
6.7 3.3 5.7 2.1 Iris-virginica
5.8 2.8 5.1 2.4 Iris-virginica
6.0 3.4 4.5 1.6 Iris-versicolor
5.4 3.0 4.5 1.5 Iris-versicolor
5.5 3.5 1.3 0.2 Iris-setosa
5.0 3.3 1.4 0.2 Iris-setosa
5.7 4.4 1.5 0.4 Iris-setosa
5.3 3.7 1.5 0.2 Iris-setosa
5.2 3.5 1.5 0.2 Iris-setosa
6.5 2.8 4.6 1.5 Iris-versicolor
7.4 2.8 6.1 1.9 Iris-virginica
4.9 3.1 1.5 0.2 Iris-setosa
5.0 3.2 1.2 0.2 Iris-setosa
7.7 2.8 6.7 2.0 Iris-virginica
4.8 3.4 1.9 0.2 Iris-setosa
6.5 3.0 5.2 2.0 Iris-virginica
6.3 2.5 5.0 1.9 Iris-virginica
6.4 3.1 5.5 1.8 Iris-virginica
5.8 2.7 5.1 1.9 Iris-virginica
7.1 3.0 5.9 2.1 Iris-virginica
5.7 2.5 5.0 2.0 Iris-virginica
6.4 2.8 5.6 2.2 Iris-virginica
6.4 3.2 4.5 1.5 Iris-versicolor
6.1 2.6 5.6 1.4 Iris-virginica
4.8 3.0 1.4 0.1 Iris-setosa
5.6 2.8 4.9 2.0 Iris-virginica
6.0 2.2 5.0 1.5 Iris-virginica
5.0 3.5 1.3 0.3 Iris-setosa
5.5 2.6 4.4 1.2 Iris-versicolor
5.0 3.6 1.4 0.2 Iris-setosa
5.0 3.4 1.6 0.4 Iris-setosa
6.3 2.7 4.9 1.8 Iris-virginica
6.7 3.1 4.7 1.5 Iris-versicolor
6.3 2.5 4.9 1.5 Iris-versicolor
4.5 2.3 1.3 0.3 Iris-setosa
6.8 3.2 5.9 2.3 Iris-virginica
7.2 3.2 6.0 1.8 Iris-virginica
5.5 2.4 3.8 1.1 Iris-versicolor
5.8 2.7 5.1 1.9 Iris-virginica
6.1 2.8 4.0 1.3 Iris-versicolor
6.3 2.9 5.6 1.8 Iris-virginica
6.1 2.9 4.7 1.4 Iris-versicolor
6.3 2.3 4.4 1.3 Iris-versicolor
4.6 3.4 1.4 0.3 Iris-setosa
5.5 4.2 1.4 0.2 Iris-setosa
6.5 3.0 5.5 1.8 Iris-virginica
6.7 3.1 4.4 1.4 Iris-versicolor
6.6 2.9 4.6 1.3 Iris-versicolor
5.9 3.0 5.1 1.8 Iris-virginica
6.4 2.7 5.3 1.9 Iris-virginica
5.6 2.5 3.9 1.1 Iris-versicolor
6.4 3.2 5.3 2.3 Iris-virginica
5.7 3.8 1.7 0.3 Iris-setosa
7.2 3.0 5.8 1.6 Iris-virginica
6.7 3.0 5.2 2.3 Iris-virginica
4.6 3.1 1.5 0.2 Iris-setosa
5.6 2.9 3.6 1.3 Iris-versicolor
6.4 2.9 4.3 1.3 Iris-versicolor
5.1 3.5 1.4 0.2 Iris-setosa
7.6 3.0 6.6 2.1 Iris-virginica
5.7 2.8 4.1 1.3 Iris-versicolor
5.6 2.7 4.2 1.3 Iris-versicolor
5.7 2.9 4.2 1.3 Iris-versicolor
5.4 3.7 1.5 0.2 Iris-setosa
6.4 2.8 5.6 2.1 Iris-virginica
4.6 3.6 1.0 0.2 Iris-setosa
4.4 2.9 1.4 0.2 Iris-setosa
4.4 3.2 1.3 0.2 Iris-setosa
6.2 3.4 5.4 2.3 Iris-virginica
6.3 3.4 5.6 2.4 Iris-virginica
6.8 2.8 4.8 1.4 Iris-versicolor
5.1 3.4 1.5 0.2 Iris-setosa
6.1 3.0 4.9 1.8 Iris-virginica
5.7 3.0 4.2 1.2 Iris-versicolor
5.0 3.4 1.5 0.2 Iris-setosa
5.0 3.5 1.6 0.6 Iris-setosa
7.7 3.8 6.7 2.2 Iris-virginica
4.9 3.1 1.5 0.1 Iris-setosa
6.0 2.2 4.0 1.0 Iris-versicolor
6.8 3.0 5.5 2.1 Iris-virginica
5.1 2.5 3.0 1.1 Iris-versicolor
6.5 3.2 5.1 2.0 Iris-virginica
4.7 3.2 1.3 0.2 Iris-setosa
6.6 3.0 4.4 1.4 Iris-versicolor
6.7 3.0 5.0 1.7 Iris-versicolor
4.8 3.0 1.4 0.3 Iris-setosa
5.1 3.8 1.5 0.3 Iris-setosa
7.7 2.6 6.9 2.3 Iris-virginica
5.1 3.8 1.6 0.2 Iris-setosa
5.0 2.0 3.5 1.0 Iris-versicolor
7.7 3.0 6.1 2.3 Iris-virginica
6.5 3.0 5.8 2.2 Iris-virginica
5.8 4.0 1.2 0.2 Iris-setosa
5.4 3.4 1.5 0.4 Iris-setosa
6.2 2.2 4.5 1.5 Iris-versicolor
5.7 2.8 4.5 1.3 Iris-versicolor
5.5 2.5 4.0 1.3 Iris-versicolor
7.3 2.9 6.3 1.8 Iris-virginica
5.6 3.0 4.5 1.5 Iris-versicolor
6.2 2.8 4.8 1.8 Iris-virginica
4.3 3.0 1.1 0.1 Iris-setosa
5.8 2.7 3.9 1.2 Iris-versicolor
7.9 3.8 6.4 2.0 Iris-virginica
6.2 2.9 4.3 1.3 Iris-versicolor
4.9 2.5 4.5 1.7 Iris-virginica
4.9 3.6 1.4 0.1 Iris-setosa
5.2 3.4 1.4 0.2 Iris-setosa
6.0 2.7 5.1 1.6 Iris-versicolor
6.9 3.1 5.4 2.1 Iris-virginica
4.8 3.1 1.6 0.2 Iris-setosa
6.7 3.1 5.6 2.4 Iris-virginica
6.3 3.3 4.7 1.6 Iris-versicolor
5.2 4.1 1.5 0.1 Iris-setosa
5.4 3.9 1.3 0.4 Iris-setosa
4.9 2.4 3.3 1.0 Iris-versicolor
5.5 2.4 3.7 1.0 Iris-versicolor
5.1 3.5 1.4 0.3 Iris-setosa
6.1 3.0 4.6 1.4 Iris-versicolor
5.1 3.3 1.7 0.5 Iris-setosa
4.4 3.0 1.3 0.2 Iris-setosa
5.9 3.2 4.8 1.8 Iris-versicolor
4.7 3.2 1.6 0.2 Iris-setosa
6.9 3.1 5.1 2.3 Iris-virginica
5.4 3.9 1.7 0.4 Iris-setosa
5.8 2.7 4.1 1.0 Iris-versicolor
6.1 2.8 4.7 1.2 Iris-versicolor
6.0 3.0 4.8 1.8 Iris-virginica
6.7 2.5 5.8 1.8 Iris-virginica
================================================
FILE: project/build.properties
================================================
sbt.version=0.13.9
================================================
FILE: project/plugins.sbt
================================================
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.13.0")
addSbtPlugin("no.arktekk.sbt" % "aether-deploy" % "0.14")
addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "0.2.1")
addSbtPlugin("com.typesafe.sbt" % "sbt-pgp" % "0.8.3")
================================================
FILE: sbt
================================================
#!/usr/bin/env bash
#
# A more capable sbt runner, coincidentally also called sbt.
# Author: Paul Phillips
# todo - make this dynamic
# Launcher versions: sbt_release_version is the fallback when neither
# -sbt-version nor project/build.properties supplies one (set_sbt_version),
# and is also forced by -sbt-force-latest; sbt_unreleased_version is
# selected by -sbt-dev (see process_args).
declare -r sbt_release_version="0.13.8"
declare -r sbt_unreleased_version="0.13.9-M1"
declare -r buildProps="project/build.properties"
# Globals filled in later by option parsing and version detection.
declare sbt_jar sbt_dir sbt_create sbt_version
declare scala_version sbt_explicit_version
declare verbose noshare batch trace_level log_level
declare sbt_saved_stty debugUs
# Print all arguments on stderr.
echoerr () { echo >&2 "$@"; }
# Like echoerr, but only when $verbose is non-empty (set by -v). When not
# verbose it returns non-zero, so "vlog ... && cmd" skips cmd.
vlog () { [[ -n "$verbose" ]] && echoerr "$@"; }
# Print the sbt.version value from $buildProps when that file is readable
# (return 1 otherwise). Spaces around '=' are tolerated, e.g.
# "sbt.version = 0.13.0": '=' and '\r' become spaces and awk prints the
# second field.
build_props_sbt () {
  if [[ -r "$buildProps" ]]; then
    grep '^sbt\.version' "$buildProps" | tr '=\r' ' ' | awk '{ print $2; }'
  else
    return 1
  fi
}
# Persist an explicitly requested sbt version: when $buildProps is readable
# and its current sbt.version differs from "$1", rewrite the sbt.version line
# in place (perl -pi). If no "sbt.version" line exists afterwards, append one.
# Logs the old and new values when verbose.
update_build_props_sbt () {
local ver="$1"
local old="$(build_props_sbt)"
[[ -r "$buildProps" ]] && [[ "$ver" != "$old" ]] && {
perl -pi -e "s/^sbt\.version\b.*\$/sbt.version=${ver}/" "$buildProps"
grep -q '^sbt.version[ =]' "$buildProps" || printf "\nsbt.version=%s\n" "$ver" >> "$buildProps"
vlog "!!!"
vlog "!!! Updated file $buildProps setting sbt.version to: $ver"
vlog "!!! Previous value was: $old"
vlog "!!!"
}
}
# Decide which sbt version to launch and export it as $sbt_version.
# Precedence: explicit -sbt-version/-sbt-dev/-sbt-force-latest, then the
# value in build.properties, then $sbt_release_version.
set_sbt_version () {
  sbt_version="$sbt_explicit_version"
  [[ -n "$sbt_version" ]] || sbt_version="$(build_props_sbt)"
  [[ -n "$sbt_version" ]] || sbt_version="$sbt_release_version"
  export sbt_version
}
# restore stty settings (echo in particular)
# EXIT-trap handler: if a saved terminal state exists, restore it with stty
# and unset it so that a second invocation does nothing.
onSbtRunnerExit() {
[[ -n "$sbt_saved_stty" ]] || return
vlog ""
vlog "restoring stty: $sbt_saved_stty"
stty "$sbt_saved_stty"
unset sbt_saved_stty
}
# save stty and trap exit, to ensure echo is reenabled if we are interrupted.
# stty -g fails (stderr suppressed) when stdin is not a terminal, leaving
# sbt_saved_stty empty, so the handler becomes a no-op.
trap onSbtRunnerExit EXIT
sbt_saved_stty="$(stty -g 2>/dev/null)"
# NOTE(review): -v is only parsed later by process_args, so $verbose is
# still empty here and this message appears never to be printed.
vlog "Saved stty: $sbt_saved_stty"
# Resolve one level of symlink for "$1": print the path unchanged when it
# is not a symlink. Otherwise print the link target, which is made relative
# to the link's directory unless the target is absolute. (Intended to cover
# OSX, where readlink -f is unavailable.)
get_script_path () {
  local path="$1" target
  if [[ ! -L "$path" ]]; then
    echo "$path"
    return
  fi
  target="$(readlink "$path")"
  case "$target" in
    /*) echo "$target" ;;
    *)  echo "${path%/*}/$target" ;;
  esac
}
# Print "Aborting: <message>" on stdout and exit the script with status 1.
die() {
  printf '%s\n' "Aborting: $*"
  exit 1
}
# Print the launcher download URL for sbt version "$1". 0.7.x comes from a
# fixed googlecode URL; 0.10.x and 0.11.1/0.11.2 use the org.scala-tools.sbt
# organization under $sbt_launch_repo; everything else uses org.scala-sbt.
# Sets the global $version as a side effect.
make_url () {
  version="$1"
  local organization
  case "$version" in
    0.7.*)
      echo "http://simple-build-tool.googlecode.com/files/sbt-launch-0.7.7.jar"
      return
      ;;
    0.10.*|0.11.[12]) organization="org.scala-tools.sbt" ;;
    *) organization="org.scala-sbt" ;;
  esac
  echo "$sbt_launch_repo/$organization/sbt-launch/$version/sbt-launch.jar"
}
# Choose an options file. The default file "$2" wins when it is readable.
# Otherwise, if the variable named by "$1" holds "@<file>" and <file> is
# readable, print <file>; in every other case print "$2".
init_default_option_file () {
  local from_env="${!1}"
  local fallback="$2"
  if [[ -r "$fallback" ]]; then
    echo "$fallback"
    return
  fi
  if [[ "$from_env" =~ ^@(.*)$ ]] && [[ -r "${BASH_REMATCH[1]}" ]]; then
    echo "${BASH_REMATCH[1]}"
  else
    echo "$fallback"
  fi
}
# Default JVM flags; default_jvm_opts adds MaxPermSize for pre-8 JVMs.
# NOTE(review): the CMS collector flags are deprecated/removed in newer
# JDKs; confirm against the JVMs this is run with.
declare -r cms_opts="-XX:+CMSClassUnloadingEnabled -XX:+UseConcMarkSweepGC"
declare -r jit_opts="-XX:ReservedCodeCacheSize=256m -XX:+TieredCompilation"
declare -r default_jvm_opts_common="-Xms512m -Xmx1536m -Xss2m $jit_opts $cms_opts"
# Project-local sbt/ivy directories (presumably applied when -no-share sets $noshare).
declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy"
# Scala versions selected by the -28/-29/-210/-211 shortcuts.
declare -r latest_28="2.8.2"
declare -r latest_29="2.9.3"
declare -r latest_210="2.10.5"
declare -r latest_211="2.11.7"
declare -r script_path="$(get_script_path "$BASH_SOURCE")"
declare -r script_name="${script_path##*/}"
# some non-read-onlies set with defaults
declare java_cmd="java"
declare sbt_opts_file="$(init_default_option_file SBT_OPTS .sbtopts)"
declare jvm_opts_file="$(init_default_option_file JVM_OPTS .jvmopts)"
# NOTE(review): plain-http repository URL; can be overridden with -sbt-launch-repo.
declare sbt_launch_repo="http://repo.typesafe.com/typesafe/ivy-releases"
# pull -J and -D options to give to java.
declare -a residual_args
declare -a java_args
declare -a scalac_args
declare -a sbt_commands
# args to jvm/sbt via files or environment variables
declare -a extra_jvm_opts extra_sbt_opts
# Append one argument to the options passed to the JVM.
addJava () {
vlog "[addJava] arg = '$1'"
java_args+=("$1")
}
# Append one command to the list passed to sbt.
addSbt () {
vlog "[addSbt] arg = '$1'"
sbt_commands+=("$1")
}
# Queue the sbt command "set <key> in ThisBuild := <value>". Every argument
# after the key is joined with spaces ($*) into one command string. With "$@"
# inside the string, several value words expanded to several arguments, and
# addSbt (which only reads $1) silently dropped all but the first.
setThisBuild () {
vlog "[addBuild] args = '$*'"
local key="$1" && shift
addSbt "set $key in ThisBuild := $*"
}
# Append one option to scalac_args (collected from -S<opt>).
addScalac () {
vlog "[addScalac] arg = '$1'"
scalac_args+=("$1")
}
# Keep an argument that is not a runner option; it is passed on to sbt.
addResidual () {
vlog "[residual] arg = '$1'"
residual_args+=("$1")
}
# Queue "set resolvers += <resolver expression>" for sbt.
addResolver () {
addSbt "set resolvers += $1"
}
# Enable a JDWP debugger socket on port "$1" (suspend=n: do not wait for attach).
addDebugger () {
addJava "-Xdebug"
addJava "-Xrunjdwp:transport=dt_socket,server=y,suspend=n,address=$1"
}
# Switch sbt to Scala version "$1" ("++ <version>"); -SNAPSHOT versions also
# get the sonatype snapshots resolver.
setScalaVersion () {
[[ "$1" == *"-SNAPSHOT" ]] && addResolver 'Resolver.sonatypeRepo("snapshots")'
addSbt "++ $1"
}
# Use the JDK at "$1": run its java, tell sbt via javaHome, and export
# JAVA_HOME/JDK_HOME and PATH for child processes.
setJavaHome () {
java_cmd="$1/bin/java"
setThisBuild javaHome "Some(file(\"$1\"))"
export JAVA_HOME="$1"
export JDK_HOME="$1"
export PATH="$JAVA_HOME/bin:$PATH"
}
# setJavaHome wrapped in "warn"/"info" sbt commands, presumably to silence
# sbt's log output while the javaHome setting is applied.
setJavaHomeQuietly () {
addSbt warn
setJavaHome "$1"
addSbt info
}
# if set, use JDK_HOME/JAVA_HOME over java found in path
# JDK_HOME is preferred when it contains lib/tools.jar; otherwise JAVA_HOME
# is used when it contains bin/java.
if [[ -e "$JDK_HOME/lib/tools.jar" ]]; then
setJavaHomeQuietly "$JDK_HOME"
elif [[ -e "$JAVA_HOME/bin/java" ]]; then
setJavaHomeQuietly "$JAVA_HOME"
fi
# directory to store sbt launchers
# Created if missing; if it is not writable, a fresh temporary directory
# from mktemp is used instead.
declare sbt_launch_dir="$HOME/.sbt/launchers"
[[ -d "$sbt_launch_dir" ]] || mkdir -p "$sbt_launch_dir"
[[ -w "$sbt_launch_dir" ]] || sbt_launch_dir="$(mktemp -d -t sbt_extras_launchers.XXXXXX)"
# Print the major version of the JVM run by "$java_cmd" (e.g. 7, 8, 11).
# JVMs up to 8 report "1.<major>..." (1.7.0_80, 1.8.0_151), while Java 9+
# reports "<major>..." or just "<major>" (9-ea, 11.0.2, 17). Reading a fixed
# character offset only works for the 1.x form; for 11.0.2 it yielded "0",
# which made default_jvm_opts add the obsolete MaxPermSize flag.
java_version () {
local version=$("$java_cmd" -version 2>&1 | grep -E -e '(java|openjdk) version' | awk '{ print $3 }' | tr -d \")
vlog "Detected Java version: $version"
# drop the legacy "1." prefix, then keep only the leading run of digits
[[ "$version" == 1.* ]] && version="${version#1.}"
echo "${version%%[!0-9]*}"
}
# Print the default JVM flags. JVMs older than 8 also get
# -XX:MaxPermSize=384m, which is critical there but only produces a noisy
# warning on 8+.
default_jvm_opts () {
  local major="$(java_version)"
  local opts="$default_jvm_opts_common"
  [[ $major -ge 8 ]] || opts="-XX:MaxPermSize=384m $opts"
  echo "$opts"
}
# Print the value of build.scala.versions from $buildProps (cut at the first
# " ." sequence). Prints nothing, and returns 0, when the file is unreadable.
# Sets the globals versionLine/versionString as a side effect.
build_props_scala () {
  if [[ ! -r "$buildProps" ]]; then
    return 0
  fi
  versionLine="$(grep '^build.scala.versions' "$buildProps")"
  versionString="${versionLine##build.scala.versions=}"
  echo "${versionString%% .*}"
}
execRunner () {
# print the arguments one to a line, quoting any containing spaces
vlog "# Executing command line:" && {
for arg; do
if [[ -n "$arg" ]]; then
if printf "%s\n" "$arg" | grep -q ' '; then
printf >&2 "\"%s\"\n" "$arg"
else
printf >&2 "%s\n" "$arg"
fi
fi
done
vlog ""
}
[[ -n "$batch" ]] && exec /dev/null; then
curl --fail --silent --location "$url" --output "$jar"
elif which wget >/dev/null; then
wget --quiet -O "$jar" "$url"
fi
} && [[ -r "$jar" ]]
}
# Resolve the launcher URL and local jar path for $sbt_version (stored in
# the globals sbt_url/sbt_jar) and download the jar unless it is already
# readable.
acquire_sbt_jar () {
  sbt_url="$(jar_url "$sbt_version")"
  sbt_jar="$(jar_file "$sbt_version")"
  if [[ ! -r "$sbt_jar" ]]; then
    download_url "$sbt_url" "$sbt_jar"
  fi
}
usage () {
cat < display stack traces with a max of frames (default: -1, traces suppressed)
-debug-inc enable debugging log for the incremental compiler
-no-colors disable ANSI color codes
-sbt-create start sbt even if current directory contains no sbt project
-sbt-dir path to global settings/plugins directory (default: ~/.sbt/)
-sbt-boot path to shared boot directory (default: ~/.sbt/boot in 0.11+)
-ivy path to local Ivy repository (default: ~/.ivy2)
-no-share use all local caches; no sharing
-offline put sbt in offline mode
-jvm-debug Turn on JVM debugging, open at the given port.
-batch Disable interactive mode
-prompt Set the sbt prompt; in expr, 's' is the State and 'e' is Extracted
# sbt version (default: sbt.version from $buildProps if present, otherwise $sbt_release_version)
-sbt-force-latest force the use of the latest release of sbt: $sbt_release_version
-sbt-version use the specified version of sbt (default: $sbt_release_version)
-sbt-dev use the latest pre-release version of sbt: $sbt_unreleased_version
-sbt-jar use the specified jar as the sbt launcher
-sbt-launch-dir directory to hold sbt launchers (default: ~/.sbt/launchers)
-sbt-launch-repo repo url for downloading sbt launcher jar (default: $sbt_launch_repo)
# scala version (default: as chosen by sbt)
-28 use $latest_28
-29 use $latest_29
-210 use $latest_210
-211 use $latest_211
-scala-home use the scala build at the specified directory
-scala-version use the specified version of scala
-binary-version use the specified scala version when searching for dependencies
# java version (default: java from PATH, currently $(java -version 2>&1 | grep version))
-java-home alternate JAVA_HOME
# passing options to the jvm - note it does NOT use JAVA_OPTS due to pollution
# The default set is used if JVM_OPTS is unset and no -jvm-opts file is found
$(default_jvm_opts)
JVM_OPTS environment variable holding either the jvm args directly, or
the reference to a file containing jvm args if given path is prepended by '@' (e.g. '@/etc/jvmopts')
Note: "@"-file is overridden by local '.jvmopts' or '-jvm-opts' argument.
-jvm-opts file containing jvm args (if not given, .jvmopts in project root is used if present)
-Dkey=val pass -Dkey=val directly to the jvm
-J-X pass option -X directly to the jvm (-J is stripped)
# passing options to sbt, OR to this runner
SBT_OPTS environment variable holding either the sbt args directly, or
the reference to a file containing sbt args if given path is prepended by '@' (e.g. '@/etc/sbtopts')
Note: "@"-file is overridden by local '.sbtopts' or '-sbt-opts' argument.
-sbt-opts file containing sbt args (if not given, .sbtopts in project root is used if present)
-S-X add -X to sbt's scalacOptions (-S is stripped)
EOM
}
# Parse runner options from "$@". Recognized options set globals or queue
# JVM/sbt/scalac arguments (addJava, addSbt, addScalac, setThisBuild,
# setScalaVersion, ...). Every unrecognized argument is appended to
# residual_args via addResidual. Options that take a value validate it with
# require_arg and consume two words (shift 2). The function is called once
# for the command line and again for .sbtopts/$SBT_OPTS contents, so it only
# appends to the global arrays and never resets them.
process_args ()
{
# Abort via die unless the value "$3" is non-empty and does not start with
# '-'. "$1" names the expected kind of value and "$2" the option, both used
# only in the error message.
require_arg () {
local type="$1"
local opt="$2"
local arg="$3"
if [[ -z "$arg" ]] || [[ "${arg:0:1}" == "-" ]]; then
die "$opt requires <$type> argument"
fi
}
while [[ $# -gt 0 ]]; do
case "$1" in
-h|-help) usage; exit 1 ;;
-v) verbose=true && shift ;;
-d) addSbt "--debug" && addSbt debug && shift ;;
-w) addSbt "--warn" && addSbt warn && shift ;;
-q) addSbt "--error" && addSbt error && shift ;;
-x) debugUs=true && shift ;;
-trace) require_arg integer "$1" "$2" && trace_level="$2" && shift 2 ;;
-ivy) require_arg path "$1" "$2" && addJava "-Dsbt.ivy.home=$2" && shift 2 ;;
-no-colors) addJava "-Dsbt.log.noformat=true" && shift ;;
-no-share) noshare=true && shift ;;
-sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;;
-sbt-dir) require_arg path "$1" "$2" && sbt_dir="$2" && shift 2 ;;
-debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;;
-offline) addSbt "set offline := true" && shift ;;
-jvm-debug) require_arg port "$1" "$2" && addDebugger "$2" && shift 2 ;;
-batch) batch=true && shift ;;
-prompt) require_arg "expr" "$1" "$2" && setThisBuild shellPrompt "(s => { val e = Project.extract(s) ; $2 })" && shift 2 ;;
-sbt-create) sbt_create=true && shift ;;
-sbt-jar) require_arg path "$1" "$2" && sbt_jar="$2" && shift 2 ;;
-sbt-version) require_arg version "$1" "$2" && sbt_explicit_version="$2" && shift 2 ;;
-sbt-force-latest) sbt_explicit_version="$sbt_release_version" && shift ;;
-sbt-dev) sbt_explicit_version="$sbt_unreleased_version" && shift ;;
-sbt-launch-dir) require_arg path "$1" "$2" && sbt_launch_dir="$2" && shift 2 ;;
-sbt-launch-repo) require_arg path "$1" "$2" && sbt_launch_repo="$2" && shift 2 ;;
-scala-version) require_arg version "$1" "$2" && setScalaVersion "$2" && shift 2 ;;
-binary-version) require_arg version "$1" "$2" && setThisBuild scalaBinaryVersion "\"$2\"" && shift 2 ;;
-scala-home) require_arg path "$1" "$2" && setThisBuild scalaHome "Some(file(\"$2\"))" && shift 2 ;;
-java-home) require_arg path "$1" "$2" && setJavaHome "$2" && shift 2 ;;
-sbt-opts) require_arg path "$1" "$2" && sbt_opts_file="$2" && shift 2 ;;
-jvm-opts) require_arg path "$1" "$2" && jvm_opts_file="$2" && shift 2 ;;
-D*) addJava "$1" && shift ;;
-J*) addJava "${1:2}" && shift ;;
-S*) addScalac "${1:2}" && shift ;;
-28) setScalaVersion "$latest_28" && shift ;;
-29) setScalaVersion "$latest_29" && shift ;;
-210) setScalaVersion "$latest_210" && shift ;;
-211) setScalaVersion "$latest_211" && shift ;;
--debug) addSbt debug && addResidual "$1" && shift ;;
--warn) addSbt warn && addResidual "$1" && shift ;;
--error) addSbt error && addResidual "$1" && shift ;;
*) addResidual "$1" && shift ;;
esac
done
}
# process the direct command line arguments
process_args "$@"
# Print every line of file "$1" except blank lines and lines starting with
# '#'. read's default IFS trims leading/trailing whitespace.
# -r keeps backslashes literal (e.g. Windows paths in -D options).
# "|| [[ -n $line ]]" also processes a final line that lacks a newline.
# printf is used instead of echo so a line such as "-n" or "-e" is printed
# verbatim rather than interpreted as an echo flag.
readConfigFile() {
while read -r line || [[ -n "$line" ]]; do
[[ $line =~ ^# ]] || [[ -z $line ]] || printf '%s\n' "$line"
done < "$1"
}
# if there are file/environment sbt_opts, process again so we
# can supply args to this runner
# Source precedence: a readable $sbt_opts_file (one option per non-comment
# line), else $SBT_OPTS itself unless it is an "@file" reference. The second
# process_args pass appends to residual_args, so these options come after
# the command-line ones.
if [[ -r "$sbt_opts_file" ]]; then
vlog "Using sbt options defined in file $sbt_opts_file"
while read opt; do extra_sbt_opts+=("$opt"); done < <(readConfigFile "$sbt_opts_file")
elif [[ -n "$SBT_OPTS" && ! ("$SBT_OPTS" =~ ^@.*) ]]; then
vlog "Using sbt options defined in variable \$SBT_OPTS"
# Deliberately unquoted so $SBT_OPTS is word-split into separate options.
# NOTE(review): the unquoted expansion is also subject to glob expansion.
extra_sbt_opts=( $SBT_OPTS )
else
vlog "No extra sbt options have been defined"
fi
[[ -n "${extra_sbt_opts[*]}" ]] && process_args "${extra_sbt_opts[@]}"
# reset "$@" to the residual args
set -- "${residual_args[@]}"
argumentCount=$#
# set sbt version
set_sbt_version
# only exists in 0.12+
# Queue "set traceLevel in ThisBuild := $trace_level", or report on stderr
# that the detected sbt version is too old to support it.
setTraceLevel() {
case "$sbt_version" in
"0.7."* | "0.10."* | "0.11."* ) echoerr "Cannot set trace level in sbt version $sbt_version" ;;
*) setThisBuild traceLevel $trace_level ;;
esac
}
# set scalacOptions if we were given any -S opts.
# Each collected option becomes its own quoted string in a Seq(...).
# Interpolating "${scalac_args[@]}" inside a larger string splits into
# several words, and addSbt keeps only its first argument, so every option
# after the first was silently dropped.
if [[ ${#scalac_args[@]} -gt 0 ]]; then
scalac_opts_seq=""
for scalac_opt in "${scalac_args[@]}"; do
scalac_opts_seq+="${scalac_opts_seq:+, }\"$scalac_opt\""
done
addSbt "set scalacOptions in ThisBuild ++= Seq($scalac_opts_seq)"
fi
# Update build.properties on disk to set explicit version - sbt gives us no choice
[[ -n "$sbt_explicit_version" ]] && update_build_props_sbt "$sbt_explicit_version"
vlog "Detected sbt version $sbt_version"
# NOTE(review): scala_version is declared but not assigned anywhere in the
# visible script (setScalaVersion only queues "++ <version>"), so this
# message appears never to be logged.
[[ -n "$scala_version" ]] && vlog "Overriding scala version to $scala_version"
# no args - alert them there's stuff in here
# With no residual arguments, default to running sbt's interactive "shell".
(( argumentCount > 0 )) || {
vlog "Starting $script_name: invoke with -help for other options"
residual_args=( shell )
}
# verify this is an sbt dir or -create was given
[[ -r ./build.sbt || -d ./project || -n "$sbt_create" ]] || {
cat < implements java.io.Serializable {
private static final long serialVersionUID = 123L;

/** First element; public so callers may read or assign it directly. */
public F first;

/** Second element; public so callers may read or assign it directly. */
public S second;

/**
 * Creates a pair holding the two given values. No validation is performed.
 *
 * @param first the first element
 * @param second the second element
 */
public GenericPair(F first, S second) {
    this.second = second;
    this.first = first;
}
/**
 * Returns the first element of the pair.
 *
 * @return the first element
 */
public F getFirst() {
    return first;
}

/**
 * Replaces the first element of the pair.
 *
 * @param first the new first element
 */
public void setFirst(F first) {
    this.first = first;
}

/**
 * Returns the second element of the pair.
 *
 * @return the second element
 */
public S getSecond() {
    return second;
}

/**
 * Replaces the second element of the pair.
 *
 * @param second the new second element
 */
public void setSecond(S second) {
    this.second = second;
}
/**
 * Renders the pair as {@code first,second}, using {@code String.valueOf}
 * for each element (so null prints as "null").
 */
@Override
public String toString() {
    return new StringBuilder().append(first).append(',').append(second).toString();
}
@SuppressWarnings("unchecked")
public boolean equals(Object o) {
if (!(o instanceof GenericPair, ?>))
return false;
GenericPair p = (GenericPair)o;
return (p.first).equals(first) && (p.second).equals(second);
}
public int hashCode() {
return 17 + first.hashCode() * 31 + second.hashCode();
}
}
================================================
FILE: src/main/java/com/etsy/conjecture/PrimitivePair.java
================================================
package com.etsy.conjecture;
/**
* PrimitivePair is JavaBean
*
* @author Josh Attenberg
*/
public class PrimitivePair implements java.io.Serializable {
private static final long serialVersionUID = 1234L;
public double first;
public double second;
/**
* Class constructor specifying the first and second number to create
*
* @param first
* first number
* @param second
* second number
*/
public PrimitivePair(double first, double second) {
this.first = first;
this.second = second;
}
/**
* The method gets first number
*
* @return first number
*/
public double getFirst() {
return first;
}
/**
* The method sets first number
*
* @param fisrt
* first number
*/
public void setFirst(double fisrt) {
this.first = fisrt;
}
/**
* The method gets second number
*
* @return second number
*/
public double getSecond() {
return second;
}
/**
* The method sets second number
*
* @param second
* second number
*/
public void setSecond(double second) {
this.second = second;
}
@Override
public String toString() {
return first + "," + second;
}
@Override
public boolean equals(Object o) {
if (!(o instanceof PrimitivePair))
return false;
PrimitivePair p = (PrimitivePair)o;
return p.first == first && p.second == second;
}
@Override
public int hashCode() {
return (17 + Utilities.doubleHash(first)) * 31
+ Utilities.doubleHash(second);
}
}
================================================
FILE: src/main/java/com/etsy/conjecture/Utilities.java
================================================
package com.etsy.conjecture;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import org.apache.commons.lang.StringUtils;
import com.google.common.hash.*;
import com.google.common.collect.Lists;
/**
* class of static data science utility methods
*
* @author jattenberg
*
*/
public class Utilities {
public static final double SMALL = 1e-10;
public static final HashFunction HASHER = Hashing.md5();
public static final double ROOT2 = Math.sqrt(2d);
public static final double LOG2 = Math.log(2.);
private Utilities() {
}
public static String cleanLine(String line) {
StringBuffer buffer = new StringBuffer();
for (int i = 0; i < line.length(); i++) {
char c = line.charAt(i);
if (c < 128 && Character.isLetter(c)) {
buffer.append(c);
} else {
buffer.append(' ');
}
}
return buffer.toString().toLowerCase();
}
public static String cleanLineRobust(String input, String separator,
boolean ignoreNumbers) {
StringBuilder buff = new StringBuilder();
StringTokenizer tokenizer = new StringTokenizer(input,
" +.,~\\<>\\$?!:;(){}|" + "\b\t\n\f\r\"\'\\\\/\\=\\&\\%\\_");
while (tokenizer.hasMoreTokens()) {
String token = tokenizer.nextToken();
token = token.replaceAll("-{2,}", "-");
token = token.replaceAll("^-", "");
token = token.replaceAll("-$", "");
if (token.length() < 2
|| (ignoreNumbers && StringUtils.containsAny(token,
"0123456789")))
continue;
buff.append(token + separator);
}
int index = buff.lastIndexOf(separator);
if (index >= 0)
buff.delete(index, buff.length());
return buff.toString();
}
public static String checkNotBlank(String s) {
if (StringUtils.isBlank(s)) {
throw new IllegalArgumentException("Argument cannot be blank");
}
return s;
}
public static List checkNotBlank(List S) {
for (String s : S)
checkNotBlank(s);
return S;
}
public static String[] checkNotBlank(String[] S) {
for (String s : S)
checkNotBlank(s);
return S;
}
public static double stringInnerProduct(Map coefficients,
Collection input) {
double output = 0;
for (String token : input)
output += coefficients.containsKey(token) ? coefficients.get(token)
: 0;
return output;
}
public static double sigmoid(double operand) {
return 1. / (1. + Math.exp(-operand));
}
/**
* derivative of the sigmoid function
*/
public static double dsigmoid(double operand) {
return Math.exp(operand) / Math.pow(1. + Math.exp(operand), 2.);
}
/**
* returns the strings in input in sorted order
*
* @param input
* @return
*/
public static String sortTerms(String input) {
return sortTerms(input, "\\s+");
}
public static String sortTerms(String input, String delim) {
String[] terms = input.split(delim);
Arrays.sort(terms);
return StringUtils.join(terms, delim);
}
public final static String cleanText(String tmp, int maxlen) {
StringTokenizer tok = new StringTokenizer(tmp,
" +.,~\\<>\\$?!:;(){}|-0123456789\b\t\n\f\r\"\'\\\\/\\=\\&\\%\\_");
StringBuilder buff = new StringBuilder();
while (tok.hasMoreTokens()) {
String out = tok.nextToken();
if (out.length() < 2 || out.length() > maxlen)
continue;
buff.append(out + " ");
}
return buff.toString();
}
public final static List grams(String input, int[] gramSizes,
String separator) {
List out = Lists.newArrayList();
StringBuilder buff = new StringBuilder();
String[] tokens = StringUtils.split(input);
for (int i = 0; i < tokens.length; i++) {
String token = tokens[i];
for (int len : gramSizes) {
if (len > i + 1)
continue;
if (len == 1) {
out.add(token);
continue;
}
buff.setLength(0);
for (int k = len - 1; k > 0; k--)
buff.append(tokens[i - k] + separator);
buff.append(token);
out.add(buff.toString());
}
}
return out;
}
public static final boolean floatingPointEquals(double a, double b) {
return (a - b < SMALL) && (b - a < SMALL);
}
public static int doubleHash(double d) {
long t = Double.doubleToLongBits(d);
return (int)(t ^ (t >>> 32));
}
public static double logistic(double x) {
return 1d / (1 + Math.exp(-x));
}
static class ValueComparator> implements
Comparator> {
boolean reverse;
public ValueComparator(boolean reverse) {
this.reverse = reverse;
}
public int compare(Map.Entry a, Map.Entry b) {
int res = a.getValue().compareTo(b.getValue());
return reverse ? -res : res;
}
}
public static > ArrayList orderKeysByValue(
Map map) {
return orderKeysByValue(map, false);
}
public static > ArrayList orderKeysByValue(
Map map, boolean reverse) {
ArrayList> keys = new ArrayList>();
keys.addAll(map.entrySet());
Collections.sort(keys, new ValueComparator(reverse));
ArrayList res = new ArrayList();
for (int i = 0; i < keys.size(); i++) {
res.add(keys.get(i).getKey());
}
return res;
}
public static > List topKeysByValue(
Map map, int n) {
ArrayList keys = orderKeysByValue(map, true);
ArrayList res = new ArrayList(n);
for (int i = 0; i < n && i < keys.size(); i++) {
res.add(keys.get(i));
}
return res;
}
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/AbstractInstance.java
================================================
package com.etsy.conjecture.data;
import java.util.Collection;
import java.util.List;
import java.util.Map;
public abstract class AbstractInstance> {
protected static final String SEP = "___";
public String id;
public String supporting_data;
protected double weight;
StringKeyedVector vector;
public AbstractInstance() {
this(new StringKeyedVector(), 1.0);
}
public AbstractInstance(double weight) {
this(new StringKeyedVector(), weight);
}
public AbstractInstance(StringKeyedVector skv) {
this(skv, 1.0);
}
public AbstractInstance(StringKeyedVector skv, double weight) {
this.vector = skv;
this.weight = weight;
}
public AbstractInstance(Map map) {
this(map, 1.0);
}
public AbstractInstance(Map map, double weight) {
this.vector = new StringKeyedVector(map);
this.weight = weight;
}
@SuppressWarnings("unchecked")
public T setWeight(double weight) {
this.weight = weight;
return (T)this;
}
public double getWeight() {
return weight;
}
public String getId() {
return id;
}
public StringKeyedVector getVector() {
return vector;
}
public void setSupportingData(String s) {
supporting_data = s;
}
public String getSupportingData() {
return supporting_data;
}
@SuppressWarnings("unchecked")
public T setCoordinate(String id, double value) {
vector.setCoordinate(id, value);
return (T)this;
}
@SuppressWarnings("unchecked")
public T addToCoordinate(String id, double value) {
vector.addToCoordinate(id, value);
return (T)this;
}
@SuppressWarnings("unchecked")
public T setId(String id) {
this.id = id;
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#addTerm(java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addTerm(String term) {
addTerm(term, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#addTerm(java.lang.String,
* double)
*/
@SuppressWarnings("unchecked")
public T addTerm(String term, double featureWeight) {
addToCoordinate(term, featureWeight);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addTermWithNamespace(java.
* lang.String, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addTermWithNamespace(String term, String namespace) {
addTermWithNamespace(term, namespace, 1);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addTermWithNamespace(java.
* lang.String, java.lang.String, double)
*/
@SuppressWarnings("unchecked")
public T addTermWithNamespace(String term, String namespace,
double featureWeight) {
addToCoordinate(namespace + SEP + term, featureWeight);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addTerms(java.util.Collection,
* double)
*/
@SuppressWarnings("unchecked")
public T addTerms(Collection terms, double featureWeight) {
for (String term : terms) {
addToCoordinate(term, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addTerms(java.util.Collection)
*/
@SuppressWarnings("unchecked")
public T addTerms(Collection terms) {
addTerms(terms, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addTermsWithNamespace(java
* .util.Collection, java.lang.String, double)
*/
@SuppressWarnings("unchecked")
public T addTermsWithNamespace(Collection terms, String namespace,
double featureWeight) {
for (String term : terms) {
addTermWithNamespace(term, namespace, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addTermsWithNamespace(java
* .util.Collection, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addTermsWithNamespace(Collection terms, String namespace) {
addTermsWithNamespace(terms, namespace, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addTerms(java.lang.String[],
* double)
*/
@SuppressWarnings("unchecked")
public T addTerms(String[] terms, double featureWeight) {
for (String term : terms) {
addToCoordinate(term, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addTerms(java.lang.String[])
*/
@SuppressWarnings("unchecked")
public T addTerms(String[] terms) {
addTerms(terms, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addTermsWithNamespace(java
* .lang.String[], java.lang.String, double)
*/
@SuppressWarnings("unchecked")
public T addTermsWithNamespace(String[] terms, String namespace,
double featureWeight) {
for (String term : terms) {
addTermWithNamespace(term, namespace, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addTermsWithNamespace(java
* .lang.String[], java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addTermsWithNamespace(String[] terms, String namespace) {
addTermsWithNamespace(terms, namespace, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addTermsWithWeights(java.util
* .Map)
*/
@SuppressWarnings("unchecked")
public T addTermsWithWeights(Map termsWithWeights) {
for (String term : termsWithWeights.keySet()) {
addTerm(term, termsWithWeights.get(term));
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addTermsWithWeightsWithNamespace
* (java.util.Map, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addTermsWithWeightsWithNamespace(
Map termsWithWeights, String namespace) {
for (String term : termsWithWeights.keySet()) {
addTermWithNamespace(term, namespace, termsWithWeights.get(term));
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addNumericArrayWithNamespace
* (double[], java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addNumericArrayWithNamespace(double[] array, String namespace) {
for (int i = 0; i < array.length; i++) {
addToCoordinate(namespace + SEP + i, array[i]);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#addNumericArray(double[])
*/
@SuppressWarnings("unchecked")
public T addNumericArray(double[] array) {
for (int i = 0; i < array.length; i++) {
addToCoordinate("" + i, array[i]);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addNumericArrayWithNamespace
* (java.lang.Double[], java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addNumericArrayWithNamespace(Double[] array, String namespace) {
for (int i = 0; i < array.length; i++) {
addToCoordinate(namespace + SEP + i, array[i]);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addNumericArray(java.lang.
* Double[])
*/
@SuppressWarnings("unchecked")
public T addNumericArray(Double[] array) {
for (int i = 0; i < array.length; i++) {
addToCoordinate("" + i, array[i]);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addNumericArrayWithNamespace
* (java.util.List, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addNumericArrayWithNamespace(List values, String namespace) {
for (int i = 0; i < values.size(); i++) {
addToCoordinate(namespace + SEP + i, values.get(i));
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addNumericArray(java.util.
* List)
*/
@SuppressWarnings("unchecked")
public T addNumericArray(List values) {
for (int i = 0; i < values.size(); i++) {
addToCoordinate("" + i, values.get(i));
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setNumericArrayWithNamespace
* (double[], java.lang.String)
*/
@SuppressWarnings("unchecked")
public T setNumericArrayWithNamespace(double[] array, String namespace) {
for (int i = 0; i < array.length; i++) {
addToCoordinate(namespace + SEP + i, array[i]);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#setNumericArray(double[])
*/
@SuppressWarnings("unchecked")
public T setNumericArray(double[] array) {
for (int i = 0; i < array.length; i++) {
addToCoordinate("" + i, array[i]);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setNumericArrayWithNamespace
* (java.lang.Double[], java.lang.String)
*/
@SuppressWarnings("unchecked")
public T setNumericArrayWithNamespace(Double[] array, String namespace) {
for (int i = 0; i < array.length; i++) {
addToCoordinate(namespace + SEP + i, array[i]);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setNumericArray(java.lang.
* Double[])
*/
@SuppressWarnings("unchecked")
public T setNumericArray(Double[] array) {
for (int i = 0; i < array.length; i++) {
addToCoordinate("" + i, array[i]);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setNumericArrayWithNamespace
* (java.util.List, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T setNumericArrayWithNamespace(List values, String namespace) {
for (int i = 0; i < values.size(); i++) {
addToCoordinate(namespace + SEP + i, values.get(i));
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setNumericArray(java.util.
* List)
*/
@SuppressWarnings("unchecked")
public T setNumericArray(List values) {
for (int i = 0; i < values.size(); i++) {
addToCoordinate("" + i, values.get(i));
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#addIdField(long, double)
*/
@SuppressWarnings("unchecked")
public T addIdField(long id, double featureWeight) {
addToCoordinate("" + id, featureWeight);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#addIdField(long)
*/
@SuppressWarnings("unchecked")
public T addIdField(long id) {
addIdField(id, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addIdFieldWithNamespace(long,
* double, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addIdFieldWithNamespace(long id, double featureWeight,
String namespace) {
addToCoordinate(namespace + SEP + id, featureWeight);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addIdFieldWithNamespace(long,
* java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addIdFieldWithNamespace(long id, String namespace) {
addIdFieldWithNamespace(id, 1., namespace);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#addIdField(int, double)
*/
@SuppressWarnings("unchecked")
public T addIdField(int id, double featureWeight) {
addToCoordinate("" + id, featureWeight);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#addIdField(int)
*/
@SuppressWarnings("unchecked")
public T addIdField(int id) {
addIdField(id, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addIdFieldWithNamespace(int,
* double, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addIdFieldWithNamespace(int id, double featureWeight,
String namespace) {
addToCoordinate(namespace + SEP + id, featureWeight);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addIdFieldWithNamespace(int,
* java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addIdFieldWithNamespace(int id, String namespace) {
addIdFieldWithNamespace(id, 1., namespace);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#addIds(long[], double)
*/
@SuppressWarnings("unchecked")
public T addIds(long[] ids, double featureWeight) {
for (long id : ids) {
addToCoordinate("" + id, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#addIds(long[])
*/
@SuppressWarnings("unchecked")
public T addIds(long[] ids) {
addIds(ids, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#addIds(int[], double)
*/
@SuppressWarnings("unchecked")
public T addIds(int[] ids, double featureWeight) {
for (long id : ids) {
addToCoordinate("" + id, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#addIds(int[])
*/
@SuppressWarnings("unchecked")
public T addIds(int[] ids) {
addIds(ids, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addIds(java.util.Collection,
* double)
*/
@SuppressWarnings("unchecked")
public T addIds(Collection ids, double featureWeight) {
for (long id : ids) {
addToCoordinate("" + id, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addIds(java.util.Collection)
*/
@SuppressWarnings("unchecked")
public T addIds(Collection ids) {
addIds(ids, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addIdsWithNamespace(long[],
* double, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addIdsWithNamespace(long[] ids, double featureWeight,
String namespace) {
for (long id : ids) {
addToCoordinate(namespace + SEP + id, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addIdsWithNamespace(long[],
* java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addIdsWithNamespace(long[] ids, String namespace) {
addIdsWithNamespace(ids, 1., namespace);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addIdsWithNamespace(int[],
* double, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addIdsWithNamespace(int[] ids, double featureWeight,
String namespace) {
for (int id : ids) {
addToCoordinate(namespace + SEP + id, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addIdsWithNamespace(int[],
* java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addIdsWithNamespace(int[] ids, String namespace) {
addIdsWithNamespace(ids, 1., namespace);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addIdsWithNamespace(java.util
* .Collection, double, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addIdsWithNamespace(Collection ids, double featureWeight,
String namespace) {
for (Long id : ids) {
addToCoordinate(namespace + SEP + id, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#addIdsWithNamespace(java.util
* .Collection, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T addIdsWithNamespace(Collection ids, String namespace) {
addIdsWithNamespace(ids, 1., namespace);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#setIdField(long, double)
*/
@SuppressWarnings("unchecked")
public T setIdField(long id, double featureWeight) {
addToCoordinate("" + id, featureWeight);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#setIdField(long)
*/
@SuppressWarnings("unchecked")
public T setIdField(long id) {
setIdField(id, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setIdFieldWithNamespace(long,
* double, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T setIdFieldWithNamespace(long id, double featureWeight,
String namespace) {
addToCoordinate(namespace + SEP + id, featureWeight);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setIdFieldWithNamespace(long,
* java.lang.String)
*/
@SuppressWarnings("unchecked")
public T setIdFieldWithNamespace(long id, String namespace) {
setIdFieldWithNamespace(id, 1., namespace);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#setIdField(int, double)
*/
@SuppressWarnings("unchecked")
public T setIdField(int id, double featureWeight) {
addToCoordinate("" + id, featureWeight);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#setIdField(int)
*/
@SuppressWarnings("unchecked")
public T setIdField(int id) {
setIdField(id, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setIdFieldWithNamespace(int,
* double, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T setIdFieldWithNamespace(int id, double featureWeight,
String namespace) {
addToCoordinate(namespace + SEP + id, featureWeight);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setIdFieldWithNamespace(int,
* java.lang.String)
*/
@SuppressWarnings("unchecked")
public T setIdFieldWithNamespace(int id, String namespace) {
setIdFieldWithNamespace(id, 1., namespace);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#setIds(long[], double)
*/
@SuppressWarnings("unchecked")
public T setIds(long[] ids, double featureWeight) {
for (long id : ids) {
addToCoordinate("" + id, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#setIds(long[])
*/
@SuppressWarnings("unchecked")
public T setIds(long[] ids) {
setIds(ids, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#setIds(int[], double)
*/
@SuppressWarnings("unchecked")
public T setIds(int[] ids, double featureWeight) {
for (long id : ids) {
addToCoordinate("" + id, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see com.etsy.conjecture.data.InstanceInterface#setIds(int[])
*/
@SuppressWarnings("unchecked")
public T setIds(int[] ids) {
setIds(ids, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setIds(java.util.Collection,
* double)
*/
@SuppressWarnings("unchecked")
public T setIds(Collection ids, double featureWeight) {
for (long id : ids) {
addToCoordinate("" + id, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setIds(java.util.Collection)
*/
@SuppressWarnings("unchecked")
public T setIds(Collection ids) {
setIds(ids, 1.);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setIdsWithNamespace(long[],
* double, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T setIdsWithNamespace(long[] ids, double featureWeight,
String namespace) {
for (long id : ids) {
addToCoordinate(namespace + SEP + id, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setIdsWithNamespace(long[],
* java.lang.String)
*/
@SuppressWarnings("unchecked")
public T setIdsWithNamespace(long[] ids, String namespace) {
setIdsWithNamespace(ids, 1., namespace);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setIdsWithNamespace(int[],
* double, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T setIdsWithNamespace(int[] ids, double featureWeight,
String namespace) {
for (int id : ids) {
addToCoordinate(namespace + SEP + id, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setIdsWithNamespace(int[],
* java.lang.String)
*/
@SuppressWarnings("unchecked")
public T setIdsWithNamespace(int[] ids, String namespace) {
setIdsWithNamespace(ids, 1., namespace);
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setIdsWithNamespace(java.util
* .Collection, double, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T setIdsWithNamespace(Collection ids, double featureWeight,
String namespace) {
for (Long id : ids) {
addToCoordinate(namespace + SEP + id, featureWeight);
}
return (T)this;
}
/*
* (non-Javadoc)
*
* @see
* com.etsy.conjecture.data.InstanceInterface#setIdsWithNamespace(java.util
* .Collection, java.lang.String)
*/
@SuppressWarnings("unchecked")
public T setIdsWithNamespace(Collection ids, String namespace) {
setIdsWithNamespace(ids, 1., namespace);
return (T)this;
}
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/BinaryLabel.java
================================================
package com.etsy.conjecture.data;
import static com.google.common.base.Preconditions.checkArgument;
/**
 * A real-valued label restricted to the closed interval [0, 1].
 */
public class BinaryLabel extends RealValuedLabel {
    private static final long serialVersionUID = 1L;

    /** Label with value 0.0. */
    public BinaryLabel() {
        super(0.0);
    }

    /**
     * @param value label value in [0, 1]
     * @throws IllegalArgumentException if {@code value} is outside [0, 1] or NaN
     */
    public BinaryLabel(double value) {
        super(checkBinaryValue(value));
    }

    /** Returns {@code value} unchanged after validating it lies in [0, 1]. */
    private static double checkBinaryValue(double value) {
        boolean inUnitInterval = value >= 0 && value <= 1;
        checkArgument(inUnitInterval,
                "value must be in [0, 1], given: %s", value);
        return value;
    }

    /** Maps the {0, +1} scale onto {-1, +1}: {@code 2 * (value - 0.5)}. */
    public double getAsPlusMinus() {
        double centered = getValue() - 0.5;
        return 2.0 * centered;
    }
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/BinaryLabeledInstance.java
================================================
package com.etsy.conjecture.data;
import java.util.Map;
/**
* TODO: when using method string all methods return a RealValueLabeledInstance
* think about how to avoid this while not using generic types
*/
public class BinaryLabeledInstance extends
AbstractInstance implements
LabeledInstance {
protected BinaryLabel label;
public BinaryLabel getLabel() {
return label;
}
public BinaryLabeledInstance() {
this(new BinaryLabel(0.0), 1.0);
}
public BinaryLabeledInstance(double label, Map instance) {
this(new BinaryLabel(label), instance, 1.0);
}
public BinaryLabeledInstance(double label, Map instance,
double weight) {
this(new BinaryLabel(label), instance, weight);
}
public BinaryLabeledInstance(double label, StringKeyedVector vec) {
this(new BinaryLabel(label), vec.getMap(), 1.0);
}
public BinaryLabeledInstance(double label, StringKeyedVector vec,
double weight) {
this(new BinaryLabel(label), vec.getMap(), weight);
}
public BinaryLabeledInstance(BinaryLabel label, Map instance) {
this(label, instance, 1.0);
}
public BinaryLabeledInstance(BinaryLabel label,
Map instance, double weight) {
super(instance, weight);
this.label = label;
}
public BinaryLabeledInstance(BinaryLabel label, StringKeyedVector vec) {
this(label, vec.getMap(), 1.0);
}
public BinaryLabeledInstance(BinaryLabel label, StringKeyedVector vec,
double weight) {
this(label, vec.getMap(), weight);
}
public BinaryLabeledInstance(double label) {
this(new BinaryLabel(label), 1.0);
}
public BinaryLabeledInstance(double label, double weight) {
this(new BinaryLabel(label), weight);
}
public BinaryLabeledInstance(BinaryLabel label) {
this(label, 1.0);
}
public BinaryLabeledInstance(BinaryLabel label, double weight) {
super(weight);
this.label = label;
}
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/ByteArrayDoubleHashMap.java
================================================
package com.etsy.conjecture.data;
import gnu.trove.function.TDoubleFunction;
import gnu.trove.iterator.TObjectDoubleIterator;
import gnu.trove.map.hash.TObjectDoubleHashMap;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
public class ByteArrayDoubleHashMap implements Serializable, KryoSerializable,
Iterable>, Map {
private static final long serialVersionUID = -7070522686694887436L;
// - represent the sparse map by a mapping of coordinate name strings
// (feature names)
// to doubles.
protected TObjectDoubleHashMap map;
protected String keyEncoding;
protected float loadFactor;
protected double defaultValue;
/** Empty map: initial capacity 10, load factor 0.8, default value 0.0, ASCII keys. */
public ByteArrayDoubleHashMap() {
    this(10, 0.8f, 0.0);
}

/** Empty map with ASCII-encoded keys. */
public ByteArrayDoubleHashMap(int initialCapacity, float loadFactor,
        double defaultValue) {
    this(initialCapacity, loadFactor, "ASCII", defaultValue);
}

/**
 * Empty map whose String keys are stored as bytes in {@code keyEncoding}.
 */
public ByteArrayDoubleHashMap(int initialCapacity, float loadFactor,
        String keyEncoding, double defaultValue) {
    this.keyEncoding = keyEncoding;
    this.loadFactor = loadFactor;
    this.defaultValue = defaultValue;
    this.map = new TByteArrayDoubleHashMap(initialCapacity, loadFactor,
            defaultValue);
}
/**
 * Decodes a stored key back into a String using {@code keyEncoding}.
 *
 * @throws IllegalStateException if {@code keyEncoding} is not supported
 *             by this JVM
 */
public String byteArrayToString(byte[] b) {
    try {
        return new String(b, keyEncoding);
    } catch (UnsupportedEncodingException e) {
        // keyEncoding is fixed at construction, so this is a configuration
        // error; returning null would leak a null key into lookups.
        throw new IllegalStateException("Unsupported key encoding: "
                + keyEncoding, e);
    }
}

/**
 * Encodes a String key to bytes using {@code keyEncoding}.
 *
 * @throws IllegalStateException if {@code keyEncoding} is not supported
 *             by this JVM
 */
public byte[] stringToByteArray(String s) {
    try {
        return s.getBytes(keyEncoding);
    } catch (UnsupportedEncodingException e) {
        throw new IllegalStateException("Unsupported key encoding: "
                + keyEncoding, e);
    }
}
/**
 * Customized trove hashmap which does both: customized hash/equality
 * functions, and also storing the values as a primitive array.
 * <p>
 * Java arrays use identity-based hashCode/equals, so {@link #hash(Object)}
 * and {@link #equals(Object, Object)} are overridden to compare
 * {@code byte[]} keys by content.
 */
static class TByteArrayDoubleHashMap extends TObjectDoubleHashMap<byte[]> {
    public TByteArrayDoubleHashMap(int initialSize, float loadFactor,
            double defaultValue) {
        super(initialSize, loadFactor, defaultValue);
    }

    // Content-based hash so equal byte sequences land in the same slot.
    protected int hash(Object obj) {
        return Arrays.hashCode((byte[])obj);
    }

    // Content-based equality. A null or REMOVED (trove's deleted-slot marker)
    // second argument never matches.
    protected boolean equals(Object a, Object b) {
        return b != null && b != REMOVED
                && Arrays.equals((byte[])a, (byte[])b);
    }

    // - override this to prevent doubling on resize.
    // Same bookkeeping as a trove put, but new insertions go through
    // postInsertHook2 (below) for its gentler growth policy.
    // NOTE(review): returns 0.0 for a new key rather than the configured
    // default value — confirm no caller relies on the return value there.
    public double put(byte[] key, double value) {
        int index = insertKey(key);
        double previous = 0.0;
        boolean isNewMapping = true;
        // A negative index means the key already existed at slot -index - 1.
        if (index < 0) {
            index = -index - 1;
            previous = _values[index];
            isNewMapping = false;
        }
        _values[index] = value;
        if (isNewMapping) {
            // consumeFreeSlot is trove state set by insertKey; presumably true
            // when a FREE (never used) slot was taken — confirm for this trove version.
            postInsertHook2(consumeFreeSlot);
        }
        return previous;
    }

    // Post-insert bookkeeping. When size exceeds _maxSize, grow to the next
    // prime above ~1.2x capacity (+10) rather than doubling. When no free
    // slots remain, rehash at the current capacity (presumably to reclaim
    // slots held by removed entries).
    protected final void postInsertHook2(boolean usedFreeSlot) {
        if (usedFreeSlot) {
            _free--;
        }
        if (++_size > _maxSize || _free == 0) {
            int newCapacity = _size > _maxSize ? gnu.trove.impl.PrimeFinder
                    .nextPrime((int)(capacity() * 1.2) + 10) : capacity();
            // Logs only very large rehashes.
            if (newCapacity > 1000000) {
                System.out.println("rehashing to size: " + newCapacity
                        + " from " + capacity());
            }
            rehash(newCapacity);
            computeMaxSize(capacity());
        }
    }
}
public int size() {
return map.size();
}
public boolean containsKey(Object key) {
if (key instanceof byte[]) {
return map.containsKey(key);
} else if (key instanceof String) {
return map.containsKey(stringToByteArray((String)key));
} else {
throw new IllegalArgumentException("class "
+ key.getClass().toString()
+ " is not valid for ByteArrayDoubleHashMap.containsKey");
}
}
public Set keySet() {
Set res = new HashSet();
for (byte[] b : map.keySet()) {
res.add(byteArrayToString(b));
}
return res;
}
public Set values() {
Set values = new HashSet();
for (Map.Entry e : this) {
values.add(e.getValue());
}
return values;
}
public boolean containsValue(Object d) {
return values().contains((Double)d);
}
public Set> entrySet() {
Set> entries = new HashSet>();
for (Map.Entry e : this) {
entries.add(e);
}
return entries;
}
public boolean isEmpty() {
return size() > 0;
}
public void clear() {
map.clear();
}
public Double remove(Object k) {
return removePrimitive((String)k);
}
public Double get(Object k) {
return getPrimitive((String)k);
}
public Double put(String key, Double value) {
return putPrimitive(key, value);
}
public void putAll(Map extends String, ? extends Double> m) {
for (Map.Entry extends String, ? extends Double> e : m.entrySet()) {
put((String)e.getKey(), (Double)e.getValue());
}
}
public double getPrimitive(byte[] key) {
return map.get(key);
}
public double getPrimitive(String key) {
return map.get(stringToByteArray(key));
}
public double putPrimitive(byte[] key, double value) {
return map.put(key, value);
}
public double putPrimitive(String key, double value) {
return map.put(stringToByteArray(key), value);
}
public double removePrimitive(byte[] key) {
return map.remove(key);
}
public double removePrimitive(String key) {
return map.remove(stringToByteArray(key));
}
public void transformValues(TDoubleFunction func) {
map.transformValues(func);
}
public TObjectDoubleIterator troveIterator() {
return map.iterator();
}
public Iterator> iterator() {
return new Iterator>() {
private TObjectDoubleIterator iter = troveIterator();
public boolean hasNext() {
return iter.hasNext();
}
public void remove() {
iter.remove();
}
public Map.Entry next() {
iter.advance();
return new AbstractMap.SimpleImmutableEntry(
byteArrayToString(iter.key()), iter.value());
}
};
}
// - java serialization
private void writeObject(ObjectOutputStream output) throws IOException {
output.writeObject(keyEncoding);
output.writeFloat(loadFactor);
output.writeDouble(defaultValue);
output.writeInt(map.size());
for (TObjectDoubleIterator it = map.iterator(); it.hasNext();) {
it.advance();
byte[] key = it.key();
output.writeInt(key.length);
for (int i = 0; i < key.length; i++) {
output.writeByte(key[i]);
}
output.writeDouble(it.value());
}
}
private void readObject(ObjectInputStream input) throws IOException,
ClassNotFoundException {
keyEncoding = (String)input.readObject();
loadFactor = input.readFloat();
defaultValue = input.readDouble();
int size = input.readInt();
map = new TByteArrayDoubleHashMap(size, loadFactor, defaultValue);
for (int i = 0; i < size; i++) {
int length = input.readInt();
byte[] key = new byte[length];
for (int j = 0; j < length; j++) {
key[j] = input.readByte();
}
double value = input.readDouble();
map.put(key, value);
}
}
// - kryo serialization for use in scalding.
public void write(Kryo kryo, Output output) {
output.writeString(keyEncoding);
output.writeFloat(loadFactor);
output.writeDouble(defaultValue);
output.writeInt(map.size());
for (TObjectDoubleIterator it = map.iterator(); it.hasNext();) {
it.advance();
byte[] key = it.key();
output.writeInt(key.length);
for (int i = 0; i < key.length; i++) {
output.writeByte(key[i]);
}
output.writeDouble(it.value());
}
}
public void read(Kryo kryo, Input input) {
keyEncoding = input.readString();
loadFactor = input.readFloat();
defaultValue = input.readDouble();
int size = input.readInt();
map = new TByteArrayDoubleHashMap(size, loadFactor, defaultValue);
for (int i = 0; i < size; i++) {
int length = input.readInt();
byte[] key = new byte[length];
for (int j = 0; j < length; j++) {
key[j] = input.readByte();
}
double value = input.readDouble();
map.put(key, value);
}
}
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/ClusterLabel.java
================================================
package com.etsy.conjecture.data;
/**
 * Label naming the cluster an instance is assigned to. Equality and hashing
 * depend only on the (possibly null) label string.
 */
public class ClusterLabel extends Label {
    private static final long serialVersionUID = 1L;

    protected String label;

    /** Creates a label whose cluster name is null. */
    public ClusterLabel() {
        this(null);
    }

    public ClusterLabel(String label) {
        this.label = label;
    }

    public String getLabel() {
        return label;
    }

    public void setLabel(String label) {
        this.label = label;
    }

    /** Returns the raw cluster name, which may be null. */
    @Override
    public String toString() {
        return label;
    }

    @Override
    public int hashCode() {
        // - 31 * 1 + hash(label), with a null label hashing to 0.
        int labelHash = (label != null) ? label.hashCode() : 0;
        return 31 + labelHash;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        String otherLabel = ((ClusterLabel)obj).label;
        return (label == null) ? otherLabel == null : label.equals(otherLabel);
    }
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/ClusterPrediction.java
================================================
package com.etsy.conjecture.data;
import java.util.Map;
import com.google.common.collect.Maps;
/**
* Representing a probability of membership in each cluster
*/
public class ClusterPrediction extends ClusterLabel {
    private static final long serialVersionUID = -1L;

    /**
     * Cluster membership probabilities, keyed by cluster name; a private copy
     * of the map passed to the constructor.
     */
    private Map<String, Double> clusterProbs;

    /**
     * Builds a prediction and labels it with the most probable cluster.
     * Ties keep the first key in the argument's iteration order, and an empty
     * map leaves the label null.
     */
    public ClusterPrediction(Map<String, Double> clusterProbs) {
        this.clusterProbs = Maps.newHashMap(clusterProbs);
        String best = null;
        double bestProb = 0;
        boolean seenAny = false;
        for (Map.Entry<String, Double> entry : clusterProbs.entrySet()) {
            double prob = entry.getValue();
            if (!seenAny || prob > bestProb) {
                best = entry.getKey();
                bestProb = prob;
                seenAny = true;
            }
        }
        setLabel(best);
    }

    /** Returns the internal (mutable) probability map. */
    public Map<String, Double> getMap() {
        return clusterProbs;
    }
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/Instance.java
================================================
package com.etsy.conjecture.data;
//TODO: reset methods for string adders
//TODO: for instance, vector subtraction?
/**
 * A plain, unlabeled instance: the concrete, self-typed specialization of
 * {@link AbstractInstance}.
 */
public class Instance extends AbstractInstance<Instance> {
    /** Creates an instance using AbstractInstance's no-arg constructor. */
    public Instance() {
    }

    /**
     * Creates an instance from {@code vec}; whether the vector is copied or
     * shared is decided by AbstractInstance (NOTE(review): confirm).
     */
    public Instance(StringKeyedVector vec) {
        super(vec);
    }
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/InstanceFactory.java
================================================
package com.etsy.conjecture.data;
/**
 * Static helpers for building {@link Instance}s and turning them into labeled
 * instances. Not instantiable.
 */
public class InstanceFactory {
    private InstanceFactory() {
    }

    /** Returns a new, empty Instance. */
    public static Instance buildInstance() {
        return new Instance();
    }

    /**
     * Returns a new Instance built from {@code inst}'s vector. NOTE(review):
     * whether the vector is copied or shared depends on the
     * {@code Instance(StringKeyedVector)} constructor; confirm before
     * mutating either instance.
     */
    public static Instance copyInstance(Instance inst) {
        StringKeyedVector source = inst.getVector();
        return new Instance(source);
    }

    /** Labels {@code instance}'s vector with a numeric binary label. */
    public static BinaryLabeledInstance toBinaryLabeledInstance(double label,
            Instance instance) {
        StringKeyedVector features = instance.getVector();
        return new BinaryLabeledInstance(label, features);
    }

    /** Labels {@code instance}'s vector with an existing BinaryLabel. */
    public static BinaryLabeledInstance toBinaryLabeledInstance(
            BinaryLabel label, Instance instance) {
        StringKeyedVector features = instance.getVector();
        return new BinaryLabeledInstance(label, features);
    }

    /** Labels {@code instance}'s vector with a real-valued label. */
    public static RealValueLabeledInstance toRealValueLabeledInstance(
            double label, Instance instance) {
        StringKeyedVector features = instance.getVector();
        return new RealValueLabeledInstance(label, features);
    }

    /** Labels {@code instance}'s vector with an existing RealValuedLabel. */
    public static RealValueLabeledInstance toRealValueLabeledInstance(
            RealValuedLabel label, Instance instance) {
        StringKeyedVector features = instance.getVector();
        return new RealValueLabeledInstance(label, features);
    }
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/InstanceInterface.java
================================================
package com.etsy.conjecture.data;
import java.util.Collection;
import java.util.List;
import java.util.Map;
/**
 * Fluent interface for building feature instances. Every mutator returns
 * {@code T}, so calls can be chained.
 *
 * NOTE(review): generic type arguments appear stripped in this copy (the type
 * parameter declaration below and the raw Collection/Map/List parameters);
 * restore them from the canonical source rather than guessing.
 * Presumably: "add*" methods add to existing coordinates while "set*" methods
 * overwrite them; "*WithNamespace" variants scope feature names under
 * {@code namespace}; overloads without {@code featureWeight} use a default
 * weight. Confirm all three against AbstractInstance.
 */
public interface InstanceInterface> {
// - identifier accessors.
public abstract String getId();
public abstract T setId(String id);
// - free-text term features, single or bulk.
public abstract T addTerm(String term);
public abstract T addTerm(String term, double featureWeight);
public abstract T addTermWithNamespace(String term, String namespace);
public abstract T addTermWithNamespace(String term, String namespace,
double featureWeight);
public abstract T addTerms(Collection terms, double featureWeight);
public abstract T addTerms(Collection terms);
public abstract T addTermsWithNamespace(Collection terms,
String namespace, double featureWeight);
public abstract T addTermsWithNamespace(Collection terms,
String namespace);
public abstract T addTerms(String[] terms, double featureWeight);
public abstract T addTerms(String[] terms);
public abstract T addTermsWithNamespace(String[] terms, String namespace,
double featureWeight);
public abstract T addTermsWithNamespace(String[] terms, String namespace);
// - terms with per-term weights.
public abstract T addTermsWithWeights(Map termsWithWeights);
public abstract T addTermsWithWeightsWithNamespace(
Map termsWithWeights, String namespace);
// - dense numeric arrays (primitive, boxed, or list forms).
public abstract T addNumericArrayWithNamespace(double[] array,
String namespace);
public abstract T addNumericArray(double[] array);
public abstract T addNumericArrayWithNamespace(Double[] array,
String namespace);
public abstract T addNumericArray(Double[] array);
public abstract T addNumericArrayWithNamespace(List values,
String namespace);
public abstract T addNumericArray(List values);
public abstract T setNumericArrayWithNamespace(double[] array,
String namespace);
public abstract T setNumericArray(double[] array);
public abstract T setNumericArrayWithNamespace(Double[] array,
String namespace);
public abstract T setNumericArray(Double[] array);
public abstract T setNumericArrayWithNamespace(List values,
String namespace);
public abstract T setNumericArray(List values);
// - single integer id features (long and int overloads).
public abstract T addIdField(long id, double featureWeight);
public abstract T addIdField(long id);
public abstract T addIdFieldWithNamespace(long id, double featureWeight,
String namespace);
public abstract T addIdFieldWithNamespace(long id, String namespace);
public abstract T addIdField(int id, double featureWeight);
public abstract T addIdField(int id);
public abstract T addIdFieldWithNamespace(int id, double featureWeight,
String namespace);
public abstract T addIdFieldWithNamespace(int id, String namespace);
// - bulk integer id features.
public abstract T addIds(long[] ids, double featureWeight);
public abstract T addIds(long[] ids);
public abstract T addIds(int[] ids, double featureWeight);
public abstract T addIds(int[] ids);
public abstract T addIds(Collection ids, double featureWeight);
public abstract T addIds(Collection ids);
public abstract T addIdsWithNamespace(long[] ids, double featureWeight,
String namespace);
public abstract T addIdsWithNamespace(long[] ids, String namespace);
public abstract T addIdsWithNamespace(int[] ids, double featureWeight,
String namespace);
public abstract T addIdsWithNamespace(int[] ids, String namespace);
public abstract T addIdsWithNamespace(Collection ids,
double featureWeight, String namespace);
public abstract T addIdsWithNamespace(Collection ids, String namespace);
// - "set" counterparts of the id-feature methods above.
public abstract T setIdField(long id, double featureWeight);
public abstract T setIdField(long id);
public abstract T setIdFieldWithNamespace(long id, double featureWeight,
String namespace);
public abstract T setIdFieldWithNamespace(long id, String namespace);
public abstract T setIdField(int id, double featureWeight);
public abstract T setIdField(int id);
public abstract T setIdFieldWithNamespace(int id, double featureWeight,
String namespace);
public abstract T setIdFieldWithNamespace(int id, String namespace);
public abstract T setIds(long[] ids, double featureWeight);
public abstract T setIds(long[] ids);
public abstract T setIds(int[] ids, double featureWeight);
public abstract T setIds(int[] ids);
public abstract T setIds(Collection ids, double featureWeight);
public abstract T setIds(Collection ids);
public abstract T setIdsWithNamespace(long[] ids, double featureWeight,
String namespace);
public abstract T setIdsWithNamespace(long[] ids, String namespace);
public abstract T setIdsWithNamespace(int[] ids, double featureWeight,
String namespace);
public abstract T setIdsWithNamespace(int[] ids, String namespace);
public abstract T setIdsWithNamespace(Collection ids,
double featureWeight, String namespace);
public abstract T setIdsWithNamespace(Collection ids, String namespace);
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/Label.java
================================================
package com.etsy.conjecture.data;
/**
 * Common serializable supertype of all label classes. It carries no state of
 * its own; the implicit public no-arg constructor is used.
 */
public class Label implements java.io.Serializable {
    private static final long serialVersionUID = 1L;
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/LabeledInstance.java
================================================
package com.etsy.conjecture.data;
/**
 * An instance that carries a label of type {@code L} and a weight alongside
 * its feature vector.
 */
public interface LabeledInstance<L extends Label> {
    /** Returns the feature vector of this instance. */
    StringKeyedVector getVector();

    /** Returns this instance's label. */
    L getLabel();

    /** Returns this instance's weight. */
    double getWeight();
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/LazyVector.java
================================================
package com.etsy.conjecture.data;
import gnu.trove.function.TDoubleFunction;
import gnu.trove.iterator.TObjectDoubleIterator;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import com.etsy.conjecture.Utilities;
public class LazyVector extends StringKeyedVector implements Serializable,
KryoSerializable {
private static final long serialVersionUID = -7070522686694887436L;
protected transient ByteArrayDoubleHashMap iterations;
protected long iteration = 0;
protected UpdateFunction updater;
/**
* The function used to update the parameters during the lazy update
*/
public static interface UpdateFunction extends Serializable {
public double lazyUpdate(String key, double param, long startIteration,
long endIteration);
}
public LazyVector() {
this(new UpdateFunction() {
private static final long serialVersionUID = 1740773207106961880L;
public double lazyUpdate(String key, double p, long a, long b) {
return p;
}
});
}
public LazyVector(UpdateFunction uf) {
this(10, uf);
}
public LazyVector(int initialCapacity, UpdateFunction uf) {
super(initialCapacity);
iterations = new ByteArrayDoubleHashMap(initialCapacity, LOAD_FACTOR,
FEATURE_ENCODING, 0.0);
updater = uf;
}
public LazyVector(StringKeyedVector skv, UpdateFunction uf) {
if (skv instanceof LazyVector) {
((LazyVector)skv).delazify();
}
this.vector = skv.vector;
iterations = new ByteArrayDoubleHashMap(skv.size(), LOAD_FACTOR,
FEATURE_ENCODING, 0.0);
updater = uf;
}
public LazyVector(ByteArrayDoubleHashMap map, UpdateFunction uf) {
super(map);
iterations = new ByteArrayDoubleHashMap(10, LOAD_FACTOR,
FEATURE_ENCODING, 0.0);
updater = uf;
}
public LazyVector(Map jmap, UpdateFunction uf) {
super(jmap);
iterations = new ByteArrayDoubleHashMap(10, LOAD_FACTOR,
FEATURE_ENCODING, 0.0);
updater = uf;
}
public void incrementIteration() {
iteration++;
}
public void delazify() {
for (TObjectDoubleIterator it = vector.troveIterator(); it
.hasNext();) {
it.advance();
long startIter = (long)iterations.getPrimitive(it.key()); // defaults
// to 0.0
if (startIter < iteration) {
it.setValue(updater.lazyUpdate(it.key().toString(), it.value(), startIter, iteration));
iterations.putPrimitive(it.key(), (double)iteration);
}
}
removeZeroCoordinates();
}
public double delazifyCoordinate(String key) {
return delazifyCoordinate(vector.stringToByteArray(key));
}
public double delazifyCoordinate(byte[] key) {
if (vector.containsKey(key)) {
long oldIteration = (long)iterations.getPrimitive(key);
double initial = vector.getPrimitive(key);
if (oldIteration < iteration) {
double updated = updater.lazyUpdate(key.toString(), initial, oldIteration,
iteration);
if (Utilities.floatingPointEquals(updated, 0.0d)) {
vector.removePrimitive(key);
iterations.removePrimitive(key);
} else {
iterations.putPrimitive(key, (double)iteration);
vector.putPrimitive(key, updated);
}
return updated;
} else {
return initial;
}
}
return 0.0;
}
public void skipToIteration(long iter) {
delazify();
iteration = iter;
for (TObjectDoubleIterator it = iterations.troveIterator(); it
.hasNext();) {
it.advance();
it.setValue((double)iter);
}
}
/**
* disregards prior value at a particular key, replacing with the specified
* value.
*/
public double setCoordinate(String key, double value) {
if (Utilities.floatingPointEquals(value, 0d)) {
return deleteCoordinate(key);
} else if (!freezeKeySet) {
vector.putPrimitive(key, value);
iterations.putPrimitive(key, (double)iteration);
}
return 0d;
}
/**
* remove a coordinate from the vector (same as setting it to 0).
*/
public double deleteCoordinate(String key) {
if (vector.containsKey(key) && !freezeKeySet) {
iterations.removePrimitive(key);
return vector.removePrimitive(key);
} else {
return 0d;
}
}
public Map getMap() {
return vector;
}
protected double addToCoordinateInternal(byte[] bkey, double value) {
delazifyCoordinate(bkey);
if (vector.containsKey(bkey)) {
double updated = vector.getPrimitive(bkey) + value;
if (Utilities.floatingPointEquals(updated, 0.0d)) {
iterations.removePrimitive(bkey);
return vector.removePrimitive(bkey);
} else {
iterations.putPrimitive(bkey, (double)iteration);
return vector.putPrimitive(bkey, updated);
}
} else if (!freezeKeySet && !Utilities.floatingPointEquals(value, 0.0d)) {
vector.putPrimitive(bkey, value);
iterations.putPrimitive(bkey, (double)iteration);
}
return 0d;
}
/**
* return the value of a coordinate.
*/
public double getCoordinate(String key) {
delazifyCoordinate(key);
return vector.getPrimitive(key);
}
/**
* the dimension of the vector.
*/
public int size() {
delazify();
return vector.size();
}
/**
* whether this vector has a non-zero value for a coordinate.
*/
public boolean containsKey(String key) {
delazify();
return vector.containsKey(key);
}
/**
* whether this vector has a non-zero value for a coordinate.
*/
public boolean contains(String key) {
return containsKey(key);
}
/**
* the set of non-zero coordinate names.
*/
public Set keySet() {
delazify();
return vector.keySet();
}
/**
* the set of values in the map.
*/
public Set values() {
delazify();
return vector.values();
}
/**
* Apply an arbitrary scalar function to the values.
*/
public void transformValues(TDoubleFunction func) {
delazify();
vector.transformValues(func);
}
/**
* Remove zeros that may have appeared as a result of a transform
*/
public void removeZeroCoordinates() {
for (TObjectDoubleIterator it = vector.troveIterator(); it
.hasNext();) {
it.advance();
if (Utilities.floatingPointEquals(it.value(), 0d)) {
iterations.removePrimitive(it.key());
it.remove();
}
}
}
/**
* compute the inner product between this and vec.
*/
public double dot(StringKeyedVector skv) {
if (skv instanceof LazyVector) {
return dotWithLazy((LazyVector)skv);
} else {
return dotWithSKV(skv);
}
}
protected double dotWithSKV(StringKeyedVector vec) {
// dont figure out which ones bigger etc, since delazifying this to get
// the size is too slow.
double res = 0.0;
for (TObjectDoubleIterator it = vec.vector.troveIterator(); it
.hasNext();) {
it.advance();
res += it.value() * delazifyCoordinate(it.key());
}
return res;
}
protected double dotWithLazy(LazyVector vec) {
ByteArrayDoubleHashMap vec_small = this.size() > vec.size() ? vec.vector
: this.vector;
ByteArrayDoubleHashMap vec_big = this.size() > vec.size() ? this.vector
: vec.vector;
ArrayList commonCoordinates = new ArrayList(); // prevent
// modification
// during
// iteration.
double res = 0.0;
for (TObjectDoubleIterator it = vec_small.troveIterator(); it
.hasNext();) {
it.advance();
if (vec_big.containsKey(it.key())) {
commonCoordinates.add(it.key());
}
}
for (byte[] key : commonCoordinates) {
delazifyCoordinate(key);
vec.delazifyCoordinate(key);
res += vec_small.getPrimitive(key) * vec_big.getPrimitive(key);
}
return res;
}
/**
* compute the LP norm for given p < infinity.
*/
public double LPNorm(double p) {
delazify();
return super.LPNorm(p);
}
/**
* immutable access the underlying hash map.
*/
public Iterator> iterator() {
delazify();
return vector.iterator();
}
public String toString() {
delazify();
return super.toString();
}
private Object writeReplace() throws java.io.ObjectStreamException {
delazify();
return this;
}
// - java serialization
private void writeObject(ObjectOutputStream output) throws IOException {
output.writeLong(iteration);
output.writeObject(vector);
output.writeObject(updater);
output.writeBoolean(freezeKeySet);
}
private void readObject(ObjectInputStream input) throws IOException,
ClassNotFoundException {
iteration = input.readLong();
vector = (ByteArrayDoubleHashMap)input.readObject();
updater = (UpdateFunction)input.readObject();
freezeKeySet = input.readBoolean();
// set up iteration info,
iterations = new ByteArrayDoubleHashMap(10, LOAD_FACTOR,
(double)iteration);
}
// - kryo serialization for use in scalding.
public void write(Kryo kryo, Output output) {
delazify();
output.writeLong(iteration);
kryo.writeObject(output, vector);
kryo.writeClassAndObject(output, updater);
output.writeBoolean(freezeKeySet);
}
public void read(Kryo kryo, Input input) {
iteration = input.readLong();
vector = kryo.readObject(input, ByteArrayDoubleHashMap.class);
updater = (UpdateFunction)kryo.readClassAndObject(input);
freezeKeySet = input.readBoolean();
// set up iteration info,
iterations = new ByteArrayDoubleHashMap(10, LOAD_FACTOR,
(double)iteration);
}
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/MulticlassLabel.java
================================================
package com.etsy.conjecture.data;
/**
* representing a 100% probability of membership in a particular class
*/
/**
 * Label asserting membership in exactly one class (the class named by
 * {@link #label}). Equality and hashing depend only on that name.
 */
public class MulticlassLabel extends Label {
    private static final long serialVersionUID = 1L;

    protected String label;

    /** Creates a label whose class name is null. */
    public MulticlassLabel() {
        this(null);
    }

    public MulticlassLabel(String label) {
        this.label = label;
    }

    public String getLabel() {
        return label;
    }

    public void setLabel(String label) {
        this.label = label;
    }

    /** Returns the raw class name, which may be null. */
    @Override
    public String toString() {
        return label;
    }

    /**
     * One-vs-rest view of this label: 1.0 if it names {@code className},
     * otherwise 0.0.
     */
    public BinaryLabel toBinaryLabel(String className) {
        double value = className.equals(label) ? 1.0 : 0.0;
        return new BinaryLabel(value);
    }

    @Override
    public int hashCode() {
        // - 31 * 1 + hash(label), with a null label hashing to 0.
        int labelHash = (label != null) ? label.hashCode() : 0;
        return 31 + labelHash;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        String otherLabel = ((MulticlassLabel)obj).label;
        return (label == null) ? otherLabel == null : label.equals(otherLabel);
    }
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/MulticlassLabeledInstance.java
================================================
package com.etsy.conjecture.data;
import java.util.Map;
public class MulticlassLabeledInstance extends
AbstractInstance implements
LabeledInstance {
protected MulticlassLabel label;
public MulticlassLabel getLabel() {
return label;
}
public MulticlassLabeledInstance(String label) {
this(new MulticlassLabel(label), 1.0);
}
public MulticlassLabeledInstance(String label, double weight) {
this(new MulticlassLabel(label), weight);
}
public MulticlassLabeledInstance(String label, Map instance) {
this(new MulticlassLabel(label), instance, 1.0);
}
public MulticlassLabeledInstance(String label,
Map instance, double weight) {
this(new MulticlassLabel(label), instance, weight);
}
public MulticlassLabeledInstance(String label, StringKeyedVector vec) {
this(new MulticlassLabel(label), vec.getMap(), 1.0);
}
public MulticlassLabeledInstance(String label, StringKeyedVector vec,
double weight) {
this(new MulticlassLabel(label), vec.getMap(), weight);
}
public MulticlassLabeledInstance(MulticlassLabel label) {
this(label, 1.0);
}
public MulticlassLabeledInstance(MulticlassLabel label, double weight) {
super(weight);
this.label = label;
}
public MulticlassLabeledInstance(MulticlassLabel label,
Map instance) {
this(label, instance, 1.0);
}
public MulticlassLabeledInstance(MulticlassLabel label,
Map instance, double weight) {
super(instance, weight);
this.label = label;
}
public MulticlassLabeledInstance(MulticlassLabel label,
StringKeyedVector vec) {
this(label, vec.getMap(), 1.0);
}
public MulticlassLabeledInstance(MulticlassLabel label,
StringKeyedVector vec, double weight) {
this(label, vec.getMap(), weight);
}
public BinaryLabeledInstance toBinaryInstance(String category) {
double tmpLabel = 0d;
if (category.equals(this.label.getLabel())) {
tmpLabel = 1d;
}
return new BinaryLabeledInstance(tmpLabel, getVector());
}
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/MulticlassPrediction.java
================================================
package com.etsy.conjecture.data;
import java.util.Map;
import com.google.common.collect.Maps;
/**
* representing a probability of membership in each class
*/
public class MulticlassPrediction extends MulticlassLabel {
    private static final long serialVersionUID = -1L;

    /**
     * class membership probabilities, keyed by class name; a private copy of
     * the map passed to the constructor.
     */
    private Map<String, Double> classProbs;

    /**
     * Builds a prediction and labels it with the most probable class. Ties
     * keep the first key in the argument's iteration order, and an empty map
     * leaves the label null.
     */
    public MulticlassPrediction(Map<String, Double> classProbs) {
        this.classProbs = Maps.newHashMap(classProbs);
        String best = null;
        double bestProb = 0;
        boolean seenAny = false;
        for (Map.Entry<String, Double> entry : classProbs.entrySet()) {
            double prob = entry.getValue();
            if (!seenAny || prob > bestProb) {
                best = entry.getKey();
                bestProb = prob;
                seenAny = true;
            }
        }
        setLabel(best);
    }

    /** Probability of {@code category}, or null if it is not present. */
    public Double getProb(String category) {
        return classProbs.get(category);
    }

    /** Probability of {@code category}, or {@code def} if it is not present. */
    public Double getProbOrElse(String category, Double def) {
        return classProbs.containsKey(category) ? classProbs.get(category) : def;
    }

    /** Returns the internal (mutable) probability map. */
    public Map<String, Double> getMap() {
        return classProbs;
    }
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/RealValueLabeledInstance.java
================================================
package com.etsy.conjecture.data;
import java.util.Map;
public class RealValueLabeledInstance extends
AbstractInstance implements
LabeledInstance {
private final RealValuedLabel label;
public RealValuedLabel getLabel() {
return label;
}
public RealValueLabeledInstance() {
this(0.0);
}
public RealValueLabeledInstance(RealValuedLabel label) {
this(label, 1.0);
}
public RealValueLabeledInstance(RealValuedLabel label, double weight) {
super(weight);
this.label = label;
}
public RealValueLabeledInstance(double label) {
this(new RealValuedLabel(label), 1.0);
}
public RealValueLabeledInstance(double label, double weight) {
this(new RealValuedLabel(label), weight);
}
public RealValueLabeledInstance(double label, Map instance) {
this(new RealValuedLabel(label), instance, 1.0);
}
public RealValueLabeledInstance(double label, Map instance,
double weight) {
this(new RealValuedLabel(label), instance, weight);
}
public RealValueLabeledInstance(double label, StringKeyedVector vec) {
this(new RealValuedLabel(label), vec.getMap(), 1.0);
}
public RealValueLabeledInstance(double label, StringKeyedVector vec,
double weight) {
this(new RealValuedLabel(label), vec.getMap(), weight);
}
public RealValueLabeledInstance(RealValuedLabel label,
Map instance) {
this(label, instance, 1.0);
}
public RealValueLabeledInstance(RealValuedLabel label,
Map instance, double weight) {
super(instance, weight);
this.label = label;
}
public RealValueLabeledInstance(RealValuedLabel label, StringKeyedVector vec) {
this(label, vec, 1.0);
}
public RealValueLabeledInstance(RealValuedLabel label,
StringKeyedVector vec, double weight) {
super(vec.getMap(), weight);
this.label = label;
}
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/RealValuedLabel.java
================================================
package com.etsy.conjecture.data;
/**
 * A label holding a single real (double) value; used by
 * {@link RealValueLabeledInstance}.
 */
public class RealValuedLabel extends Label {
    private static final long serialVersionUID = -1L;

    protected final Double value;

    public RealValuedLabel(double value) {
        // - Double.valueOf instead of the deprecated Double(double) constructor.
        this.value = Double.valueOf(value);
    }

    public Double getValue() {
        return this.value;
    }

    /** Returns the value formatted by {@link Double#toString()}. */
    @Override
    public String toString() {
        return String.valueOf(value);
    }
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/Recommendation.java
================================================
package com.etsy.conjecture.data;
import java.io.Serializable;
/**
 * An immutable (id, score) pair; both fields are public and final.
 */
public class Recommendation implements Serializable {
    private static final long serialVersionUID = 1L;

    public final String id;
    public final double score;

    public Recommendation(String id, double score) {
        this.score = score;
        this.id = id;
    }
}
================================================
FILE: src/main/java/com/etsy/conjecture/data/StringKeyedVector.java
================================================
package com.etsy.conjecture.data;
import gnu.trove.function.TDoubleFunction;
import gnu.trove.iterator.TObjectDoubleIterator;
import java.io.Serializable;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import com.etsy.conjecture.Utilities;
import com.google.gson.Gson;
public class StringKeyedVector implements Serializable,
Iterable> {
private static final long serialVersionUID = -7070522686694887436L;
// - represent the sparse vector by a mapping of coordinate name strings
// (feature names)
// to doubles.
protected ByteArrayDoubleHashMap vector;
// - whether to permit the addition of more features to this vector.
protected boolean freezeKeySet = false;
// - the load factor for the underlying hashmap.
public static final float LOAD_FACTOR = 0.9f;
public static final String FEATURE_ENCODING = "ASCII";
public StringKeyedVector() {
this(10);
}
public StringKeyedVector(int initialCapacity) {
vector = new ByteArrayDoubleHashMap(initialCapacity, LOAD_FACTOR,
FEATURE_ENCODING, 0.0);
}
public StringKeyedVector(StringKeyedVector skv) {
this(skv.size());
add(skv);
}
public StringKeyedVector(Map jmap) {
vector = new ByteArrayDoubleHashMap(jmap.size(), LOAD_FACTOR,
FEATURE_ENCODING, 0.0);
vector.putAll(jmap);
}
/**
* returns whether the key set is frozen (true means that further dimensions
* cannot be added to this vector).
*/
public boolean getFreezeKeySet() {
return freezeKeySet;
}
/**
* sets whether the key set is frozen (true means that further dimensions
* cannot be added to this vector).
*/
public void setFreezeKeySet(boolean freeze) {
freezeKeySet = freeze;
}
/**
* disregards prior value at a particular key, replacing with the specified
* value.
*/
public double setCoordinate(String key, double value) {
if (Utilities.floatingPointEquals(value, 0d)) {
return deleteCoordinate(key);
} else if (!freezeKeySet) {
vector.putPrimitive(key, value);
}
return 0d;
}
/**
* remove a coordinate from the vector (same as setting it to 0).
*/
public double deleteCoordinate(String key) {
if (vector.containsKey(key) && !freezeKeySet) {
return vector.removePrimitive(key);
} else {
return 0d;
}
}
public Map getMap() {
return vector;
}
/**
* add to a specified coordinate (treating it as 0 if it was not present).
*/
public double addToCoordinate(String key, double value) {
byte[] bkey = vector.stringToByteArray(key);
return addToCoordinateInternal(bkey, value);
}
protected double addToCoordinateInternal(byte[] bkey, double value) {
if (vector.containsKey(bkey)) {
double updated = vector.getPrimitive(bkey) + value;
if (Utilities.floatingPointEquals(updated, 0.0d)) {
return vector.removePrimitive(bkey);
} else {
return vector.putPrimitive(bkey, updated);
}
} else if (!freezeKeySet && !Utilities.floatingPointEquals(value, 0.0d)) {
vector.putPrimitive(bkey, value);
}
return 0d;
}
/**
* return the value of a coordinate.
*/
public double getCoordinate(String key) {
return vector.getPrimitive(key);
}
/**
* add a multiple of vec to this.
*/
public void addScaled(StringKeyedVector vec, double scale) {
if (vec instanceof LazyVector) {
((LazyVector)vec).delazify();
}
for (TObjectDoubleIterator it = vec.vector.troveIterator(); it
.hasNext();) {
it.advance();
addToCoordinateInternal(it.key(), scale * it.value());
}
}
public StringKeyedVector multiplyPointwise(StringKeyedVector vec) {
StringKeyedVector res = new StringKeyedVector();
if (vec instanceof LazyVector) {
((LazyVector)vec).delazify();
}
for (TObjectDoubleIterator it = vec.vector.troveIterator(); it
.hasNext();) {
it.advance();
res.vector.putPrimitive(it.key(), vector.getPrimitive(it.key())
* it.value());
}
return res;
}
public StringKeyedVector projectOntoNonZeroCoordinates(StringKeyedVector vec) {
StringKeyedVector res = new StringKeyedVector();
if (vec instanceof LazyVector) {
((LazyVector)vec).delazify();
}
for (TObjectDoubleIterator it = vec.vector.troveIterator(); it
.hasNext();) {
it.advance();
res.addToCoordinateInternal(it.key(), vector.getPrimitive(it.key()));
}
return res;
}
/**
* the dimension of the vector.
*/
public int size() {
return vector.size();
}
/**
* whether this vector has a non-zero value for a coordinate.
*/
public boolean containsKey(String key) {
return vector.containsKey(key);
}
/**
* whether this vector has a non-zero value for a coordinate.
*/
public boolean contains(String key) {
return containsKey(key);
}
/**
* the set of non-zero coordinate names.
*/
public Set keySet() {
return vector.keySet();
}
/**
* the set of values in the map.
*/
public Set values() {
return vector.values();
}
/**
* add vec to this
*/
public void add(StringKeyedVector vec) {
addScaled(vec, 1.0);
}
/**
* subtract vec from this.
*/
public void sub(StringKeyedVector vec) {
addScaled(vec, -1.0);
}
/**
* multiply this vector by a scalar.
*/
public void mul(final double a) {
transformValues(new TDoubleFunction() {
public double execute(double b) {
return a * b;
}
});
}
/**
* Apply an arbitrary scalar function to the values.
*/
public void transformValues(TDoubleFunction func) {
vector.transformValues(func);
}
/**
* Remove zeros that may have appeared as a result of a transform
*/
public void removeZeroCoordinates() {
@SuppressWarnings("unused")
int i = 0;
for (TObjectDoubleIterator it = vector.troveIterator(); it
.hasNext();) {
it.advance();
if (Utilities.floatingPointEquals(it.value(), 0d)) {
i++;
it.remove();
}
}
}
/**
* compute the inner product between this and vec.
*/
public double dot(StringKeyedVector vec) {
if (vec instanceof LazyVector) {
return vec.dot(this);
}
ByteArrayDoubleHashMap vec_small = this.size() > vec.size() ? vec.vector
: this.vector;
ByteArrayDoubleHashMap vec_big = this.size() > vec.size() ? this.vector
: vec.vector;
double res = 0.0;
for (TObjectDoubleIterator it = vec_small.troveIterator(); it
.hasNext();) {
it.advance();
if (vec_big.containsKey(it.key())) {
res += it.value() * vec_big.getPrimitive(it.key());
}
}
return res;
}
/**
* compute the LP norm for given p < infinity.
*/
public double LPNorm(double p) {
double tot = 0d;
for (double v : vector.values()) {
tot += Math.pow(Math.abs(v), p);
}
return Math.pow(tot, 1d / p);
}
/**
* Find the max value.
*/
public double max() {
double max = 0.0;
for (double v : vector.values()) {
if (v > max) {
max = v;
}
}
return max;
}
/**
* immutable access the underlying hash map.
*/
public Iterator> iterator() {
return vector.iterator();
}
public String toString() {
Gson gson = new Gson();
return gson.toJson(vector);
}
/**
* performs a deep copy of a stringkeyedvector
*
*/
public StringKeyedVector copy() {
StringKeyedVector out = new StringKeyedVector(this.size());
Iterator> it = this.iterator();
while (it.hasNext()) {
Map.Entry entry = it.next();
String key = entry.getKey();
Double value = entry.getValue();
out.setCoordinate(key, value);
}
return out;
}
}
================================================
FILE: src/main/java/com/etsy/conjecture/evaluation/BinaryModelEvaluation.java
================================================
package com.etsy.conjecture.evaluation;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import com.etsy.conjecture.data.BinaryLabel;
import com.etsy.conjecture.PrimitivePair;
/**
 * A basic container for evaluating a binary classifier. It accumulates an ROC
 * object (used for AUC and Brier score) and a 2-class confusion matrix (used
 * for accuracy, F1, precision and recall) from label / prediction pairs.
 * TODO: add getters for individual metrics
 */
public class BinaryModelEvaluation implements ModelEvaluation,
Serializable {
private static final long serialVersionUID = 1L;
// ROC accumulator backing computeAUC() and computeBrier().
private final ReceiverOperatingCharacteristic ROC;
// 2x2 confusion matrix: rows are actual labels, columns are predictions.
private final ConfusionMatrix conf;
/** Creates an empty evaluation. */
public BinaryModelEvaluation() {
ROC = new ReceiverOperatingCharacteristic();
conf = new ConfusionMatrix(2);
}
/**
 * Merges another evaluation's accumulated ROC and confusion-matrix data into
 * this one. Throws ClassCastException if {@code other} is not a
 * BinaryModelEvaluation.
 */
public void merge(ModelEvaluation other) {
BinaryModelEvaluation tempOther = (BinaryModelEvaluation) other;
ROC.add(tempOther.ROC);
conf.add(tempOther.conf);
}
/** Records a pair using the numeric value of each label. */
public void add(BinaryLabel real, BinaryLabel pred) {
add(real.getValue(), pred.getValue());
}
/**
 * Records a label / prediction pair. The label is truncated to an int to
 * select the confusion-matrix row (presumably 0 or 1 - TODO confirm); the
 * prediction is treated as the probability of the positive class and
 * collapsed to the more likely class for the confusion matrix.
 */
public void add(double label, double prediction) {
ROC.add(label, prediction);
conf.addHard((int)label, prediction);
}
/** Same as add(double, double) with first = label, second = prediction. */
public void add(PrimitivePair labelPrediction) {
ROC.add(labelPrediction);
conf.addHard((int)labelPrediction.first, labelPrediction.second);
}
/** Area under the ROC curve, as computed by the ROC accumulator. */
public double computeAUC() {
return ROC.binaryAUC();
}
/** Brier score, as computed by the ROC accumulator. */
public double computeBrier() {
return ROC.brierScore();
}
/** Overall fraction of correctly classified observations (0 if empty). */
public double computeAccy() {
return conf.computeAccuracy();
}
/** (TP + TN) / observations with respect to class {@code dim}. */
public double computeAccy(int dim) {
return conf.computeAccuracy(dim);
}
/** F-measure averaged over classes, weighted by each class's row sum. */
public double computeFmeasure() {
return conf.computeAverageFmeasure();
}
/** F-measure of class {@code dim}. */
public double computeFmeasure(int dim) {
return conf.computeFmeasure(dim);
}
/** Precision averaged over classes, weighted by each class's row sum. */
public double computePrecision() {
return conf.computeAveragePrecision();
}
/** Precision of class {@code dim}. */
public double computePrecision(int dim) {
return conf.computePrecision(dim);
}
/** Recall averaged over classes, weighted by each class's row sum. */
public double computeRecall() {
return conf.computeAverageRecall();
}
/** Recall of class {@code dim}. */
public double computeRecall(int dim) {
return conf.computeRecall(dim);
}
/** Returns every metric keyed by its display name, sorted by key. */
public Map getStatistics() {
SortedMap m = new TreeMap();
m.put("Brier", computeBrier());
m.put("Acc (avg)", computeAccy());
m.put("F1 (avg)", computeFmeasure());
m.put("Prc (avg)", computePrecision());
m.put("Rec (avg)", computeRecall());
m.put("0-class Acc", computeAccy(0));
m.put("0-class F1", computeFmeasure(0));
m.put("0-class Prc", computePrecision(0));
m.put("0-class Rec", computeRecall(0));
m.put("1-class Acc", computeAccy(1));
m.put("1-class F1", computeFmeasure(1));
m.put("1-class Prc", computePrecision(1));
m.put("1-class Rec", computeRecall(1));
m.put("1-class AUC", computeAUC());
return m;
}
/** Returns non-numeric artifacts: the rendered confusion matrix under "conf". */
public Map getObjects() {
Map m = new HashMap();
m.put("conf", conf.toString());
return m;
}
}
================================================
FILE: src/main/java/com/etsy/conjecture/evaluation/ConfusionMatrix.java
================================================
package com.etsy.conjecture.evaluation;
import java.io.Serializable;
import java.util.Collection;
import com.etsy.conjecture.PrimitivePair;
import static com.google.common.base.Preconditions.checkArgument;
/**
 * A confusion matrix recording the misclassification errors of a classifier
 * over a fixed number of classes. Rows correspond to actual labels, columns to
 * predictions. Entries are doubles so that soft (probabilistic) labels and
 * predictions, as well as weighted observations, can be accumulated.
 *
 * @author jattenberg
 */
public class ConfusionMatrix implements Serializable {

    private static final long serialVersionUID = 1L;

    /**
     * The data structure representing the confusion matrix. Rows correspond
     * to labels, columns to predictions.
     */
    private double[][] confMatrix;

    /** The number of classes represented in the confusion matrix. */
    private final int numClasses;

    /** The (possibly weighted) number of label / prediction pairs observed. */
    double obs;

    /**
     * Instantiates a new, empty confusion matrix.
     *
     * @param classes
     *            the number of target classes in the problem being considered
     */
    public ConfusionMatrix(int classes) {
        obs = 0;
        this.numClasses = classes;
        this.confMatrix = new double[numClasses][numClasses];
    }

    /**
     * Adds the entries and observation count of {@code m} to this matrix.
     *
     * @throws IllegalArgumentException
     *             if {@code m} has a different number of classes
     */
    public void add(ConfusionMatrix m) {
        // Without this check a smaller m throws ArrayIndexOutOfBounds and a
        // larger m is silently truncated.
        checkArgument(m.numClasses == numClasses,
                "cannot merge confusion matrices of different sizes (%s vs %s)",
                numClasses, m.numClasses);
        obs += m.obs;
        for (int i = 0; i < numClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                confMatrix[i][j] += m.confMatrix[i][j];
            }
        }
    }

    /**
     * Instantiates a new confusion matrix and adds some initial data. Each
     * pair goes through {@link #addInfo(double, double)}, so {@code classes}
     * must be 2.
     *
     * @param classes
     *            the number of target classes in the problem being considered
     * @param labelsAndPredictions
     *            pairs of (positive-class label, positive-class prediction)
     */
    public ConfusionMatrix(int classes,
            Collection<PrimitivePair> labelsAndPredictions) {
        this(classes);
        for (PrimitivePair p : labelsAndPredictions) {
            addInfo(p.first, p.second);
        }
    }

    /**
     * Instantiates a new confusion matrix and adds some initial data. Each
     * pair goes through {@link #addInfo(double, double)}, so {@code classes}
     * must be 2.
     *
     * @param classes
     *            the number of target classes in the problem being considered
     * @param labelsAndPredictions
     *            pairs of (positive-class label, positive-class prediction)
     */
    public ConfusionMatrix(int classes, PrimitivePair[] labelsAndPredictions) {
        this(classes);
        for (PrimitivePair p : labelsAndPredictions) {
            addInfo(p.first, p.second);
        }
    }

    /**
     * Instantiates a new confusion matrix and adds some initial data. Each
     * pair goes through {@link #addInfo(double, double)}, so {@code classes}
     * must be 2.
     *
     * @param classes
     *            the number of target classes in the problem being considered
     * @param labels
     *            the positive-class labels
     * @param predictions
     *            the positive-class predictions, parallel to {@code labels}
     * @throws IllegalArgumentException
     *             if the arrays differ in length
     */
    public ConfusionMatrix(int classes, double[] labels, double[] predictions) {
        this(classes);
        checkArgument(
                labels.length == predictions.length,
                "labels and predictions must be of the same length! (%s vs %s)",
                labels.length, predictions.length);
        for (int i = 0; i < labels.length; i++) {
            addInfo(labels[i], predictions[i]);
        }
    }

    /**
     * Adds a label / prediction pair to the confusion matrix.
     *
     * @param label
     *            the index of the actual class
     * @param guess
     *            the index of the predicted class
     */
    public void addInfo(int label, int guess) {
        obs++;
        this.confMatrix[label][guess]++;
    }

    /**
     * Adds a label / prediction pair with a soft prediction.
     *
     * @param label
     *            the index of the actual class
     * @param guess
     *            the predicted distribution over classes
     */
    public void addInfo(int label, double[] guess) {
        addInfo(label, guess, 1);
    }

    /**
     * Adds a weighted label / prediction pair with a soft prediction.
     *
     * @param label
     *            the index of the actual class
     * @param guess
     *            the predicted distribution over classes
     * @param freq
     *            the number of times to consider the input label / prediction
     *            pair
     * @throws IllegalArgumentException
     *             if {@code guess} does not have one entry per class
     */
    public void addInfo(int label, double[] guess, double freq) {
        // Guava only substitutes %s placeholders.
        checkArgument(
                guess.length == numClasses,
                "input length (%s) must match num classes in confusion matrix (%s)",
                guess.length, numClasses);
        obs += freq;
        for (int i = 0; i < numClasses; i++) {
            confMatrix[label][i] += freq * guess[i];
        }
    }

    /**
     * Adds a soft label / prediction pair; only applicable for binary (2
     * class) problems.
     *
     * @param label
     *            the actual probability of membership in the positive class
     * @param prediction
     *            the predicted probability of membership in the positive class
     * @throws IllegalArgumentException
     *             if this matrix does not have exactly 2 classes
     */
    public void addInfo(double label, double prediction) {
        checkArgument(
                2 == numClasses,
                "num classes in confusion matrix (%s) must be 2 for this method",
                numClasses);
        addInfo(new double[] { 1. - label, label }, new double[] {
                1. - prediction, prediction });
    }

    /**
     * Adds a soft label / prediction pair: cell (i, j) grows by
     * {@code softlabels[i] * guess[j]}.
     *
     * @param softlabels
     *            actual distribution of target class memberships
     * @param guess
     *            the predicted distribution of class memberships
     */
    public void addInfo(double[] softlabels, double[] guess) {
        addInfo(softlabels, guess, 1.0);
    }

    /**
     * Adds a weighted soft label / prediction pair: cell (i, j) grows by
     * {@code softlabels[i] * guess[j] * freq}.
     *
     * @param softlabels
     *            actual distribution of target class memberships
     * @param guess
     *            the predicted distribution of class memberships
     * @param freq
     *            the number of times to consider this label / prediction pair
     */
    public void addInfo(double[] softlabels, double[] guess, double freq) {
        obs += freq;
        for (int i = 0; i < numClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                confMatrix[i][j] += softlabels[i] * guess[j] * freq;
            }
        }
    }

    /**
     * Computes the actual distribution over labels.
     *
     * @return the probability of membership in each class
     */
    public double[] classDistribution() {
        double[] dists = new double[this.numClasses];
        for (int i = 0; i < numClasses; i++) {
            dists[i] = classDistribution(i);
        }
        return dists;
    }

    /**
     * Computes the actual probability of membership in a class: its row sum
     * divided by the sum of all entries (NaN when the matrix is empty).
     *
     * @param num
     *            index of the class of interest
     * @return the probability of membership in the requested class
     */
    public double classDistribution(int num) {
        double classSum = 0;
        double totSum = 0;
        for (int i = 0; i < numClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                if (i == num) {
                    classSum += confMatrix[i][j];
                }
                totSum += confMatrix[i][j];
            }
        }
        return classSum / totSum;
    }

    /**
     * Adds a weighted pair after collapsing both distributions to their most
     * likely class.
     *
     * @param softlabels
     *            actual distribution of target class memberships
     * @param guess
     *            the predicted distribution of class memberships
     * @param weight
     *            the number of times to consider this label / prediction pair
     */
    public void addHard(double[] softlabels, double[] guess, double weight) {
        addInfo(softToHard(softlabels), softToHard(guess), weight);
    }

    /**
     * Adds a pair after collapsing both distributions to their most likely
     * class.
     *
     * @param softlabels
     *            actual distribution of target class memberships
     * @param guess
     *            the predicted distribution of class memberships
     */
    public void addHard(double[] softlabels, double[] guess) {
        addInfo(softToHard(softlabels), softToHard(guess));
    }

    /**
     * Adds a pair after collapsing the predicted distribution to its most
     * likely class.
     *
     * @param label
     *            the index of the actual class of membership
     * @param guess
     *            the predicted distribution over classes
     */
    public void addHard(int label, double[] guess) {
        addInfo(label, softToHard(guess));
    }

    /**
     * Adds a pair after collapsing the prediction to the more likely class;
     * only applicable for binary (2 class) problems. Ties go to class 0.
     *
     * @param label
     *            the index of the actual class of membership
     * @param prediction
     *            the predicted probability of membership in the positive class
     */
    public void addHard(int label, double prediction) {
        addInfo(label,
                softToHard(new double[] { 1.0 - prediction, prediction }));
    }

    /**
     * Adds a weighted pair after collapsing the predicted distribution to its
     * most likely class.
     *
     * @param label
     *            the index of the actual class of membership
     * @param guess
     *            the predicted distribution over classes
     * @param freq
     *            the number of times this label / prediction pair should be
     *            considered
     */
    public void addHard(int label, double[] guess, double freq) {
        // Previously freq was dropped, recording the pair with weight 1.
        addInfo(label, softToHard(guess), freq);
    }

    /**
     * Converts a soft prediction into a one-hot indicator of the most likely
     * class (the first one on ties).
     *
     * @param scores
     *            scores of class membership; must be non-empty
     * @return 0 for every class except a 1 for the most likely class
     */
    private static double[] softToHard(double[] scores) {
        // Seed the argmax with the first score (not 0) so that inputs with
        // non-positive scores still select their true maximum.
        int maxIndex = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[maxIndex]) {
                maxIndex = i;
            }
        }
        double[] out = new double[scores.length];
        out[maxIndex] = 1;
        return out;
    }

    /**
     * Renders the matrix as a tab-separated table with a "predicted:" header
     * row and one "actually i:" row per class. Each cell is divided by
     * {@code rowDivisors[i]} and/or {@code colDivisors[j]} when non-null.
     */
    private String renderTable(double[] rowDivisors, double[] colDivisors) {
        StringBuilder buff = new StringBuilder();
        buff.append("predicted:\t");
        for (int i = 0; i < numClasses - 1; i++) {
            buff.append(i).append('\t');
        }
        buff.append(numClasses - 1).append('\n');
        for (int i = 0; i < numClasses; i++) {
            buff.append("actually ").append(i).append(":\t");
            for (int j = 0; j < numClasses; j++) {
                double cell = confMatrix[i][j];
                if (rowDivisors != null) {
                    cell /= rowDivisors[i];
                }
                if (colDivisors != null) {
                    cell /= colDivisors[j];
                }
                buff.append(String.format("%.4f\t", cell));
            }
            buff.append('\n');
        }
        return buff.toString();
    }

    /** Renders the raw matrix as a table. */
    @Override
    public String toString() {
        return renderTable(null, null);
    }

    /**
     * Renders the matrix with each entry divided by its row sum.
     *
     * @return the row-normalized table
     */
    public String toStringRowNormalized() {
        return renderTable(this.rowSums(), null);
    }

    /**
     * Renders the matrix with each entry divided by its column sum.
     *
     * @return the column-normalized table
     */
    public String toStringColNormalized() {
        // Previously divided by the row index's column sum (colSums[i]).
        return renderTable(null, this.colSums());
    }

    /**
     * Computes the sum of each row (the total weight of each actual class).
     *
     * @return an array containing the sum of each row
     */
    public double[] rowSums() {
        double[] sums = new double[numClasses];
        for (int i = 0; i < numClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                sums[i] += confMatrix[i][j];
            }
        }
        return sums;
    }

    /**
     * Computes the accuracy with respect to one class: true positives plus
     * true negatives for that class, over all observations (NaN if none).
     *
     * @param classid
     *            the index of the class of interest
     * @return the fraction of correctly classified observations for the class
     */
    public double computeAccuracy(int classid) {
        double tn = 0.;
        for (int i = 0; i < numClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                if (j != classid && i != classid) {
                    tn += confMatrix[i][j];
                }
            }
        }
        double tp = confMatrix[classid][classid];
        return (tn + tp) / obs;
    }

    /** Per-class F-measure averaged with weights equal to the row sums. */
    public double computeAverageFmeasure() {
        double[] rowSums = rowSums();
        double total = total(rowSums);
        double fmeasure = 0.;
        for (int i = 0; i < numClasses; i++) {
            fmeasure += rowSums[i] * computeFmeasure(i);
        }
        return fmeasure / total;
    }

    /** Per-class precision averaged with weights equal to the row sums. */
    public double computeAveragePrecision() {
        double[] rowSums = rowSums();
        double total = total(rowSums);
        double precision = 0.;
        for (int i = 0; i < numClasses; i++) {
            precision += rowSums[i] * computePrecision(i);
        }
        return precision / total;
    }

    /** Per-class recall averaged with weights equal to the row sums. */
    public double computeAverageRecall() {
        double[] rowSums = rowSums();
        double total = total(rowSums);
        double recall = 0.;
        for (int i = 0; i < numClasses; i++) {
            recall += rowSums[i] * computeRecall(i);
        }
        return recall / total;
    }

    /**
     * Computes the sum of each column (the total weight of each prediction).
     *
     * @return an array containing the sum of each column
     */
    public double[] colSums() {
        double[] sums = new double[numClasses];
        for (int i = 0; i < numClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                sums[j] += confMatrix[i][j];
            }
        }
        return sums;
    }

    /**
     * Returns a copy of the confusion matrix as a 2d array.
     *
     * @return a fresh copy of the matrix data
     */
    public double[][] getMatrix() {
        double[][] out = new double[numClasses][];
        for (int i = 0; i < numClasses; i++) {
            out[i] = confMatrix[i].clone();
        }
        return out;
    }

    /**
     * Gets the number of classes in the confusion matrix.
     *
     * @return the number of classes considered
     */
    public int getDim() {
        return this.numClasses;
    }

    /**
     * Computes the accuracy over all observations for all classes (fraction of
     * correctly labeled examples).
     *
     * @return accuracy over all classes, or 0 if the matrix is empty
     */
    public double computeAccuracy() {
        double tot = 0;
        double right = 0;
        for (int i = 0; i < this.numClasses; i++) {
            tot += total(this.confMatrix[i]);
            right += this.confMatrix[i][i];
        }
        return tot > 0 ? right / tot : 0;
    }

    /**
     * Computes the precision of each class: of the observations predicted as
     * the class (its column), the fraction that actually belong to it.
     * Classes that were never predicted get 0.
     *
     * @return an array containing the precision of each class
     */
    public double[] computePrecision() {
        // Precision is column-based; it was previously computed from rows,
        // which yields recall and disagreed with computePrecision(int).
        double[] precision = new double[this.numClasses];
        double[] colSums = colSums();
        for (int i = 0; i < this.numClasses; i++) {
            if (colSums[i] != 0) {
                precision[i] = confMatrix[i][i] / colSums[i];
            }
        }
        return precision;
    }

    /**
     * Computes the recall of each class: of the observations that actually
     * belong to the class (its row), the fraction predicted as it. Classes
     * that never occur get 0.
     *
     * @return an array containing the recall of each class
     */
    public double[] computeRecall() {
        double[] recall = new double[this.numClasses];
        double[] rowSums = rowSums();
        for (int i = 0; i < this.numClasses; i++) {
            if (rowSums[i] != 0) {
                recall[i] = confMatrix[i][i] / rowSums[i];
            }
        }
        return recall;
    }

    /**
     * Computes the F-measure of each class (harmonic mean of precision and
     * recall); 0 where both are 0.
     *
     * @return the array containing the F-measure of each class
     */
    public double[] computeFmeasure() {
        double[] fmeasure = new double[numClasses];
        double[] precision = this.computePrecision();
        double[] recall = this.computeRecall();
        for (int i = 0; i < this.numClasses; i++) {
            if (recall[i] + precision[i] != 0) {
                fmeasure[i] = 2.0 * (precision[i] * recall[i])
                        / (precision[i] + recall[i]);
            }
        }
        return fmeasure;
    }

    /**
     * Builds a tab-separated table of precision, recall and F-measure for
     * each class.
     *
     * @return the string with performance stats
     */
    public String getIR() {
        StringBuilder buff = new StringBuilder();
        buff.append("class\t" + "precision\t" + "recall\t" + "F measure\n");
        double[] precision = this.computePrecision();
        double[] recall = this.computeRecall();
        double[] fmeasure = this.computeFmeasure();
        for (int i = 0; i < numClasses; i++) {
            buff.append(i).append('\t').append(precision[i]).append('\t')
                    .append(recall[i]).append('\t').append(fmeasure[i])
                    .append('\n');
        }
        return buff.toString();
    }

    /**
     * Computes the precision of one class: of the observations predicted as
     * {@code dim}, the fraction that actually belong to it (NaN if never
     * predicted).
     *
     * @param dim
     *            class of interest
     * @return the precision for the requested class
     */
    public double computePrecision(int dim) {
        double tot = 0;
        for (int i = 0; i < numClasses; i++) {
            tot += confMatrix[i][dim];
        }
        return confMatrix[dim][dim] / tot;
    }

    /**
     * Computes the recall of one class: of the observations that actually
     * belong to {@code dim}, the fraction predicted as it (NaN if the class
     * never occurs).
     *
     * @param dim
     *            the class of interest
     * @return the recall for the requested class
     */
    public double computeRecall(int dim) {
        double tot = 0;
        for (int i = 0; i < numClasses; i++) {
            tot += confMatrix[dim][i];
        }
        return confMatrix[dim][dim] / tot;
    }

    /**
     * Computes the F-measure of one class (harmonic mean of precision and
     * recall). Returns 0 when precision and recall are both 0, consistent
     * with {@link #computeFmeasure()}.
     *
     * @param dim
     *            the class of interest
     * @return the F-measure of the requested class
     */
    public double computeFmeasure(int dim) {
        double pre = computePrecision(dim);
        double rec = computeRecall(dim);
        if (pre + rec == 0) {
            return 0;
        }
        return 2 * (pre * rec) / (pre + rec);
    }

    /** Sums the entries of {@code arr}. */
    private double total(double[] arr) {
        double total = 0;
        for (double v : arr) {
            total += v;
        }
        return total;
    }

    /**
     * Builds a 2-class confusion matrix from soft (positive-class label,
     * positive-class prediction) pairs and returns its overall accuracy.
     *
     * @param input
     *            the input label / prediction pairs
     * @return the accuracy of the input values
     */
    public static double computeAccuracy(Collection<PrimitivePair> input) {
        ConfusionMatrix conf = new ConfusionMatrix(2);
        for (PrimitivePair p : input) {
            conf.addInfo(new double[] { 1. - p.first, p.first }, new double[] {
                    1. - p.second, p.second });
        }
        return conf.computeAccuracy();
    }
}
================================================
FILE: src/main/java/com/etsy/conjecture/evaluation/EvaluationAggregator.java
================================================
package com.etsy.conjecture.evaluation;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import com.etsy.conjecture.data.Label;
/**
 * Aggregates several ModelEvaluation results (e.g. one per fold) into
 * per-statistic DescriptiveStatistics, plus per-key lists of the non-numeric
 * objects each evaluation reports.
 */
public class EvaluationAggregator implements Serializable {
private static final long serialVersionUID = 5825037849957449364L;
// Statistic name -> distribution of its values across added evaluations
// (TreeMap, so iteration / printing is sorted by name).
protected Map stats = new TreeMap();
// Object key -> every value reported under that key, in insertion order.
protected Map> obj = new HashMap>();
/**
 * Adds one evaluation. The first evaluation defines the set of statistic
 * names; later evaluations must report exactly the same names, otherwise a
 * RuntimeException is thrown.
 * NOTE(review): object keys are not checked, so a later evaluation reporting
 * an object key the first one lacked makes obj.get(...) return null (NPE).
 * NOTE(review): if the first evaluation reports no statistics, stats stays
 * empty and the next add takes the initialization branch again, replacing
 * the previously collected object lists.
 */
public void add(ModelEvaluation eval) {
Map fold = eval.getStatistics();
if (!stats.isEmpty()) {
if (!fold.keySet().equals(stats.keySet())) {
throw new java.lang.RuntimeException(
"Tried to add incompatible folds, with fields:"
+ fold.keySet().toString() + " and "
+ stats.keySet().toString());
}
for (Map.Entry e : fold.entrySet()) {
stats.get(e.getKey()).addValue(e.getValue());
}
for (Map.Entry e : eval.getObjects().entrySet()) {
obj.get(e.getKey()).add(e.getValue());
}
} else {
// First evaluation: create one accumulator / list per key.
for (Map.Entry e : fold.entrySet()) {
DescriptiveStatistics ds = new DescriptiveStatistics();
ds.addValue(e.getValue());
stats.put(e.getKey(), ds);
}
for (Map.Entry e : eval.getObjects().entrySet()) {
obj.put(e.getKey(), new ArrayList(5));
obj.get(e.getKey()).add(e.getValue());
}
}
}
/**
 * Returns the mean of a statistic across all added evaluations. Throws a
 * NullPointerException for an unknown key (stats.get returns null).
 */
public double getValue(String key) {
return stats.get(key).getMean();
}
/**
 * Renders one "name: mean stddev median" line per statistic, followed by
 * every collected object grouped by key.
 */
@Override
public String toString() {
StringBuilder buff = new StringBuilder("Stat:\tMean\tStdDev\tMedian\n");
for (Map.Entry e : stats.entrySet()) {
buff.append(e.getKey() + ":\t" + format(e.getValue()) + "\n");
}
for (Map.Entry> e : obj.entrySet()) {
buff.append(e.getKey()).append(":\n");
for (Object o : e.getValue()) {
buff.append("----\n").append(o.toString()).append("\n");
}
}
return buff.toString();
}
/** Formats mean, standard deviation and median (50th percentile). */
private String format(DescriptiveStatistics stats) {
return String.format("%.4f\t%.4f\t%.4f", stats.getMean(),
stats.getStandardDeviation(), stats.getPercentile(50));
}
}
================================================
FILE: src/main/java/com/etsy/conjecture/evaluation/ModelEvaluation.java
================================================
package com.etsy.conjecture.evaluation;
import com.etsy.conjecture.data.Label;
import java.util.Map;
/**
 * Accumulates actual / predicted label pairs for a model and summarizes them
 * as named statistics. {@code L} is the label type (presumably bounded by
 * Label, which this file imports - confirm against the declaration).
 */
public interface ModelEvaluation {
/** Records one actual / predicted label pair. */
public void add(L real, L predicted);
/** Returns numeric metrics keyed by name. */
public Map getStatistics();
/** Returns non-numeric artifacts keyed by name (e.g. a rendered confusion matrix). */
public Map getObjects();
/** Merges another evaluation's accumulated state into this one. */
public void merge(ModelEvaluation other);
}
================================================
FILE: src/main/java/com/etsy/conjecture/evaluation/MulticlassConfusionMatrix.java
================================================
package com.etsy.conjecture.evaluation;
import java.io.Serializable;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
/**
* class representing a confusion matrix for representing misclassification
* errors.
* {@link Confusion Matrix }
*
* @author jattenberg
*/
public class MulticlassConfusionMatrix implements Serializable {
private static final long serialVersionUID = 1L;
/**
* The data structure representing the confusion matrix. rows correspond to
* labels, columns to predictions
*/
private final SortedMap> confusionMatrix;
/** The num_classes represented in the confusion matrix */
private final int numClasses;
/** The number of label / prediction pairs observed */
double obs;
/**
* Instantiates a new confusion matrix.
*
* @param classes
* the number of target classes in the problem being considered
*/
public MulticlassConfusionMatrix(String[] categories) {
obs = 0;
this.numClasses = categories.length;
confusionMatrix = initializeMatrix(categories);
}
public void add(MulticlassConfusionMatrix m) {
obs += m.obs;
for(Map.Entry> entry : m.confusionMatrix.entrySet()) {
String label = entry.getKey();
SortedMap value = entry.getValue();
for(Map.Entry inner_entry : value.entrySet()) {
String inner_label = inner_entry.getKey();
Double update = inner_entry.getValue();
confusionMatrix.get(label).put(inner_label, update + getValue(label, inner_label));
}
}
}
private static SortedMap> initializeMatrix(
Set categories) {
String[] catArray = new String[categories.size()];
int ct = 0;
for (String category : categories) {
catArray[ct++] = category;
}
return initializeMatrix(catArray);
}
private static SortedMap> initializeMatrix(
String[] categories) {
SortedMap> conf = new TreeMap>();
for (String categoryOuter : categories) {
conf.put(categoryOuter, new TreeMap());
for (String categoryInner : categories) {
conf.get(categoryOuter).put(categoryInner, 0d);
}
}
return conf;
}
/** Returns the cell for actual {@code label} and predicted {@code guess}. */
private Double getValue(String label, String guess) {
    SortedMap<String, Double> row = confusionMatrix.get(label);
    return row.get(guess);
}
/** Adds {@code value} to the cell (label, guess). */
private void updateConfusionMatrix(String label, String guess, double value) {
    double updated = getValue(label, guess) + value;
    confusionMatrix.get(label).put(guess, updated);
}
/** Returns a sorted map from every category to 0. */
private Map<String, Double> initializeProbabilityMatrix() {
    Map<String, Double> zeros = new TreeMap<String, Double>();
    Set<String> categories = confusionMatrix.keySet();
    for (String c : categories) {
        zeros.put(c, Double.valueOf(0d));
    }
    return zeros;
}
/**
 * Records a single observation.
 *
 * @param label
 *            the actual category
 * @param guess
 *            the predicted category
 */
public void addInfo(String label, String guess) {
    obs += 1d;
    updateConfusionMatrix(label, guess, 1d);
}
/**
 * Records an observation with weight {@code freq}.
 *
 * @param label
 *            the actual category
 * @param guess
 *            the predicted category
 * @param freq
 *            the weight added to the observation count and to the cell
 */
public void addInfo(String label, String guess, double freq) {
    this.obs = this.obs + freq;
    this.updateConfusionMatrix(label, guess, freq);
}
/**
 * Records one observation with a hard label and a soft prediction.
 *
 * @param label
 *            the actual category
 * @param guesses
 *            the predicted distribution over categories
 */
public void addInfo(String label, Map<String, Double> guesses) {
    final double singleObservation = 1d;
    addInfo(label, guesses, singleObservation);
}
/**
 * Records a weighted observation with a hard label and a soft prediction:
 * cell (label, c) grows by {@code freq * predictions.get(c)} for each
 * predicted category c.
 *
 * @param label
 *            the actual category
 * @param predictions
 *            the predicted distribution over categories
 * @param freq
 *            the number of times to consider the input label / prediction
 *            pair
 */
public void addInfo(String label, Map<String, Double> predictions,
        double freq) {
    // TODO: ensure that the prediction categories match the matrix's.
    obs += freq;
    for (Map.Entry<String, Double> prediction : predictions.entrySet()) {
        // Scale by freq so the cells stay consistent with obs; previously the
        // weight was only applied to obs.
        updateConfusionMatrix(label, prediction.getKey(),
                freq * prediction.getValue());
    }
}
/**
 * Records one observation with soft labels and a soft prediction.
 *
 * @param labels
 *            actual distribution of target category memberships
 * @param predictions
 *            the predicted distribution of category memberships
 */
public void addInfo(Map<String, Double> labels,
        Map<String, Double> predictions) {
    final double weight = 1d;
    this.addInfo(labels, predictions, weight);
}
/**
 * Records a weighted observation with soft labels and a soft prediction:
 * cell (l, g) grows by {@code labels.get(l) * predictions.get(g) * freq},
 * mirroring the array-based ConfusionMatrix.
 *
 * @param labels
 *            actual distribution of target category memberships
 * @param predictions
 *            the predicted distribution of category memberships
 * @param freq
 *            the number of times to consider this label / prediction pair
 */
public void addInfo(Map<String, Double> labels,
        Map<String, Double> predictions, double freq) {
    obs += freq;
    for (Map.Entry<String, Double> label : labels.entrySet()) {
        for (Map.Entry<String, Double> prediction : predictions.entrySet()) {
            // Weight each cell by both distributions; previously the values
            // were ignored and every (label, prediction) cell got freq.
            updateConfusionMatrix(label.getKey(), prediction.getKey(),
                    label.getValue() * prediction.getValue() * freq);
        }
    }
}
/**
 * Computes the actual distribution over labels: for each category, the share
 * of all confusion-matrix mass whose actual label is that category.
 *
 * @return a map from category to its share of the recorded observations;
 *         every value is 0 when nothing has been recorded
 */
public Map<String, Double> classDistribution() {
    // One pass over the matrix via rowSums() instead of a full scan per class.
    Map<String, Double> rows = rowSums();
    double total = total(rows);
    Map<String, Double> dists = initializeProbabilityMatrix();
    for (String category : dists.keySet()) {
        dists.put(category, total > 0d ? rows.get(category) / total : 0d);
    }
    return dists;
}

/**
 * Computes the actual probability of membership in a particular class: the
 * sum of the row for {@code category} divided by the sum of the whole matrix.
 *
 * @param category
 *            the class of interest
 * @return the probability of membership in the requested class, or 0 when
 *         nothing has been recorded
 */
public double classDistribution(String category) {
    double classSum = 0d;
    double totSum = 0d;
    for (String categoryLabel : confusionMatrix.keySet()) {
        for (String categoryPrediction : confusionMatrix.keySet()) {
            double value = getValue(categoryLabel, categoryPrediction);
            // Membership is decided by the actual label (row), whatever was
            // predicted for it.
            if (categoryLabel.equals(category)) {
                classSum += value;
            }
            totSum += value;
        }
    }
    return totSum > 0d ? classSum / totSum : 0d;
}
/**
 * Adds a pair of soft distributions to the confusion matrix after hardening
 * each one to its highest-scoring class.
 *
 * @param softlabels
 *            actual distribution of target class memberships
 * @param predictions
 *            the predicted distribution of class memberships
 * @param weight
 *            the number of times to consider this label / prediction pair
 */
public void addHard(Map<String, Double> softlabels,
        Map<String, Double> predictions, double weight) {
    String actual = softToHard(softlabels);
    String predicted = softToHard(predictions);
    addInfo(actual, predicted, weight);
}

/**
 * Adds a pair of soft distributions to the confusion matrix, once, after
 * hardening each one to its highest-scoring class.
 *
 * @param softlabels
 *            actual distribution of target class memberships
 * @param predictions
 *            the predicted distribution of class memberships
 */
public void addHard(Map<String, Double> softlabels,
        Map<String, Double> predictions) {
    String actual = softToHard(softlabels);
    String predicted = softToHard(predictions);
    addInfo(actual, predicted, 1d);
}

/**
 * Adds a label / prediction pair to the confusion matrix, once, after
 * hardening the predicted scores to their highest-scoring class. The
 * prediction may cover any number of classes.
 *
 * @param label
 *            the actual class
 * @param guess
 *            the predicted score for each class
 */
public void addHard(String label, Map<String, Double> guess) {
    String predicted = softToHard(guess);
    addInfo(label, predicted, 1d);
}
/**
 * Adds a label / prediction pair to the confusion matrix {@code freq} times,
 * after hardening the predicted scores to their highest-scoring class. The
 * prediction may cover any number of classes.
 *
 * @param label
 *            the actual class
 * @param guess
 *            the predicted score for each class
 * @param freq
 *            the number of times this label / prediction pair should be
 *            considered
 */
public void addHard(String label, Map<String, Double> guess, double freq) {
    // Forward the caller's weight; passing a constant 1 silently ignored freq.
    addInfo(label, softToHard(guess), freq);
}
/**
 * Converts a soft prediction (a score per class) into a hard one by picking
 * the class with the greatest score.
 *
 * @param scores
 *            score for each class
 * @return the category with the strictly greatest score; on ties the first
 *         such category in the map's iteration order wins. Returns null when
 *         the map is empty or no score exceeds negative infinity (e.g. every
 *         score is NaN or -Infinity).
 */
private static String softToHard(Map<String, Double> scores) {
    String best = null;
    double bestScore = Double.NEGATIVE_INFINITY;
    for (Map.Entry<String, Double> entry : scores.entrySet()) {
        double score = entry.getValue();
        if (score > bestScore) {
            bestScore = score;
            best = entry.getKey();
        }
    }
    return best;
}
/**
 * Returns extra debugging output; this matrix reports none.
 *
 * @return always the empty string
 */
public String printDebug() {
    return "";
}

/**
 * Renders the raw confusion matrix as tab-separated text: a header row of
 * predicted categories, then one row per actual category with each cell
 * formatted to four decimals.
 */
@Override
public String toString() {
    StringBuilder out = new StringBuilder();
    out.append("predicted:\t");
    for (String category : confusionMatrix.keySet()) {
        out.append(category).append("\t");
    }
    out.append("\n");
    for (String actual : confusionMatrix.keySet()) {
        out.append("actually ").append(actual).append(":\t");
        for (String predicted : confusionMatrix.keySet()) {
            out.append(String.format("%.4f\t", getValue(actual, predicted)));
        }
        out.append("\n");
    }
    return out.toString();
}
/**
 * Renders the confusion matrix with every cell divided by the sum of its
 * row (the total for that actual category).
 *
 * @return the tab-separated, row-normalized matrix
 */
public String toStringRowNormalized() {
    Map<String, Double> rowTotals = rowSums();
    StringBuilder out = new StringBuilder();
    out.append("predicted:\t");
    for (String category : confusionMatrix.keySet()) {
        out.append(category).append("\t");
    }
    out.append("\n");
    for (String actual : confusionMatrix.keySet()) {
        double rowTotal = rowTotals.get(actual);
        out.append("actually ").append(actual).append(":\t");
        for (String predicted : confusionMatrix.keySet()) {
            out.append(String.format("%.4f\t",
                    getValue(actual, predicted) / rowTotal));
        }
        out.append("\n");
    }
    return out.toString();
}
/**
 * Renders the confusion matrix with every cell divided by the sum of its
 * column (the total number of times that category was predicted).
 *
 * @return the tab-separated, column-normalized matrix
 */
public String toStringColNormalized() {
    Map<String, Double> colTotals = colSums();
    StringBuilder out = new StringBuilder();
    out.append("predicted:\t");
    for (String category : confusionMatrix.keySet()) {
        out.append(category).append("\t");
    }
    out.append("\n");
    for (String actual : confusionMatrix.keySet()) {
        out.append("actually ").append(actual).append(":\t");
        for (String predicted : confusionMatrix.keySet()) {
            // Normalize by the cell's own column (predicted category), not
            // by the column that happens to share the row's key.
            out.append(String.format("%.4f\t",
                    getValue(actual, predicted) / colTotals.get(predicted)));
        }
        out.append("\n");
    }
    return out.toString();
}
/**
 * Computes the sum of each row, i.e. the total mass recorded for each
 * actual category.
 *
 * @return a map from actual category to its row sum
 */
public Map<String, Double> rowSums() {
    Map<String, Double> sums = initializeProbabilityMatrix();
    for (String label : confusionMatrix.keySet()) {
        double rowTotal = 0d;
        for (String prediction : confusionMatrix.keySet()) {
            rowTotal += getValue(label, prediction);
        }
        sums.put(label, rowTotal);
    }
    return sums;
}
/**
 * Computes the one-vs-rest accuracy for a given class: true positives plus
 * true negatives (cells whose actual and predicted labels both differ from
 * {@code classId}), divided by the observation count.
 *
 * @param classId
 *            the class whose accuracy is requested
 * @return the fraction of observations classified correctly with respect to
 *         {@code classId}
 */
public double computeAccuracy(String classId) {
    double trueNegatives = 0d;
    for (String label : confusionMatrix.keySet()) {
        if (label.equals(classId)) {
            continue;
        }
        for (String prediction : confusionMatrix.keySet()) {
            if (!prediction.equals(classId)) {
                trueNegatives += getValue(label, prediction);
            }
        }
    }
    double truePositives = getValue(classId, classId);
    return (trueNegatives + truePositives) / obs;
}
/**
 * Computes the F-measure averaged over classes, weighted by each class's
 * row sum (its number of actual members).
 *
 * Classes whose F-measure is undefined (NaN) contribute 0 but still count
 * toward the weight total, matching computeAveragePrecision() and
 * computeAverageRecall().
 *
 * @return the support-weighted average F-measure
 */
public double computeAverageFmeasure() {
    Map<String, Double> rowSums = rowSums();
    double total = total(rowSums);
    double fmeasure = 0d;
    for (String category : confusionMatrix.keySet()) {
        double f = computeFmeasure(category);
        // F is NaN when the class was never predicted (precision is 0/0) or
        // when precision and recall are both 0; without this guard a single
        // such class would turn the whole average into NaN.
        if (!Double.isNaN(f)) {
            fmeasure += rowSums.get(category) * f;
        }
    }
    return fmeasure / total;
}
/**
 * Computes precision averaged over classes, weighted by each class's row
 * sum. A class that was never predicted has NaN precision; it contributes
 * 0 but its weight still counts toward the total.
 *
 * @return the support-weighted average precision
 */
public double computeAveragePrecision() {
    Map<String, Double> support = rowSums();
    double weightTotal = total(support);
    double weighted = 0d;
    for (String category : confusionMatrix.keySet()) {
        double value = computePrecision(category);
        if (Double.isNaN(value)) {
            continue;
        }
        weighted += support.get(category) * value;
    }
    return weighted / weightTotal;
}

/**
 * Computes recall averaged over classes, weighted by each class's row sum.
 * A class with no actual members has NaN recall; it contributes 0.
 *
 * @return the support-weighted average recall
 */
public double computeAverageRecall() {
    Map<String, Double> support = rowSums();
    double weightTotal = total(support);
    double weighted = 0d;
    for (String category : confusionMatrix.keySet()) {
        double value = computeRecall(category);
        if (Double.isNaN(value)) {
            continue;
        }
        weighted += support.get(category) * value;
    }
    return weighted / weightTotal;
}
/**
 * Computes the sum of each column, i.e. the total mass predicted as each
 * category.
 *
 * @return a map from predicted category to its column sum
 */
public Map<String, Double> colSums() {
    Map<String, Double> sums = initializeProbabilityMatrix();
    for (String prediction : confusionMatrix.keySet()) {
        double columnTotal = 0d;
        for (String label : confusionMatrix.keySet()) {
            columnTotal += getValue(label, prediction);
        }
        sums.put(prediction, columnTotal);
    }
    return sums;
}
/**
* Return the confusion matrix as a 2d array
*
* @return the confusion matrix data structure
*/
public SortedMap> getMatrix() {
SortedMap> out = initializeMatrix(confusionMatrix
.keySet());
for (String categoryLabel : confusionMatrix.keySet()) {
for (String categoryPrediction : confusionMatrix.keySet()) {
out.get(categoryLabel).put(categoryPrediction,
getValue(categoryLabel, categoryPrediction));
}
}
return out;
}
/**
* Gets the number of classes in the confusion matrix
*
* @return the number of classes considered
*/
public int getDim() {
return this.numClasses;
}
/**
 * Computes the accuracy over all observations: the diagonal mass divided by
 * the total mass of the matrix.
 *
 * @return accuracy over all classes, or 0 when the matrix is empty
 */
public double computeAccuracy() {
    Map<String, Double> rowSums = rowSums();
    double total = 0d;
    double correct = 0d;
    for (String category : confusionMatrix.keySet()) {
        total += rowSums.get(category);
        correct += getValue(category, category);
    }
    return total > 0 ? correct / total : 0d;
}
/**
 * Computes the precision of every class: the fraction of observations
 * predicted as that class that actually belong to it.
 *
 * @return a map from category to its precision (NaN if never predicted)
 */
public Map<String, Double> computePrecision() {
    Map<String, Double> precision = initializeProbabilityMatrix();
    for (Map.Entry<String, Double> entry : precision.entrySet()) {
        entry.setValue(computePrecision(entry.getKey()));
    }
    return precision;
}

/**
 * Computes the recall of every class: the fraction of that class's actual
 * members that were predicted as that class.
 *
 * @return a map from category to its recall (NaN if it has no members)
 */
public Map<String, Double> computeRecall() {
    Map<String, Double> recall = initializeProbabilityMatrix();
    for (Map.Entry<String, Double> entry : recall.entrySet()) {
        entry.setValue(computeRecall(entry.getKey()));
    }
    return recall;
}
/**
 * Computes the F-measure (harmonic mean of precision and recall) for each
 * class. Classes whose precision and recall sum to exactly 0 keep the value
 * 0; NaN inputs propagate as NaN.
 *
 * @return a map from category to its F-measure
 */
public Map<String, Double> computeFmeasure() {
    Map<String, Double> precision = this.computePrecision();
    Map<String, Double> recall = this.computeRecall();
    Map<String, Double> fmeasure = initializeProbabilityMatrix();
    for (String category : confusionMatrix.keySet()) {
        double p = precision.get(category);
        double r = recall.get(category);
        double sum = p + r;
        if (sum != 0) {
            fmeasure.put(category, 2.0 * (p * r) / sum);
        }
    }
    return fmeasure;
}
/**
 * Builds a tab-separated table of the common IR measures (precision,
 * recall and F-measure), one row per class.
 *
 * @return the table as a string
 */
public String getIR() {
    Map<String, Double> precision = this.computePrecision();
    Map<String, Double> recall = this.computeRecall();
    Map<String, Double> fmeasure = this.computeFmeasure();
    StringBuilder table = new StringBuilder();
    table.append("class\t" + "precision\t" + "recall\t" + "F measure\n");
    for (String category : confusionMatrix.keySet()) {
        table.append(category).append("\t")
                .append(precision.get(category)).append("\t")
                .append(recall.get(category)).append("\t")
                .append(fmeasure.get(category)).append("\n");
    }
    return table.toString();
}
/**
 * Computes precision for a given class: of everything predicted as
 * {@code category}, the fraction whose actual label is {@code category}.
 *
 * @param category
 *            the class of interest
 * @return the precision, or NaN when the class was never predicted
 */
public double computePrecision(String category) {
    double predictedAsCategory = 0d;
    for (String label : confusionMatrix.keySet()) {
        predictedAsCategory += getValue(label, category);
    }
    return getValue(category, category) / predictedAsCategory;
}

/**
 * Computes recall for a given class: of everything whose actual label is
 * {@code category}, the fraction predicted as {@code category}.
 *
 * @param category
 *            the class of interest
 * @return the recall, or NaN when the class has no actual members
 */
public double computeRecall(String category) {
    double actuallyCategory = 0d;
    for (String prediction : confusionMatrix.keySet()) {
        actuallyCategory += getValue(category, prediction);
    }
    return getValue(category, category) / actuallyCategory;
}
/**
 * Computes the F-measure (harmonic mean of precision and recall) for a
 * given class.
 *
 * @param category
 *            the class of interest
 * @return the F-measure of the requested class; 0 when precision and recall
 *         are both 0 (consistent with computeFmeasure()), NaN when either is
 *         undefined
 */
public double computeFmeasure(String category) {
    double pre = computePrecision(category);
    double rec = computeRecall(category);
    // Both zero would yield 0/0 = NaN; the all-classes variant reports 0 here.
    if (pre + rec == 0d) {
        return 0d;
    }
    return 2 * (pre * rec) / (pre + rec);
}
/**
 * Sums every value of the given map.
 *
 * @param probs
 *            map whose values are added up
 * @return the sum of the values (0 for an empty map)
 */
private double total(Map<String, Double> probs) {
    double sum = 0d;
    for (double value : probs.values()) {
        sum += value;
    }
    return sum;
}
}
================================================
FILE: src/main/java/com/etsy/conjecture/evaluation/MulticlassModelEvaluation.java
================================================
package com.etsy.conjecture.evaluation;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import com.etsy.conjecture.GenericPair;
import com.etsy.conjecture.data.MulticlassLabel;
import com.etsy.conjecture.data.MulticlassPrediction;
/**
* a basic container for evaluations TODO: add getters for individual metrics
*/
public class MulticlassModelEvaluation implements Serializable,
ModelEvaluation {
/**
 * Serialization version of this evaluation container.
 */
private static final long serialVersionUID = 4916724871985109129L;
// ROC tracker: receives the full MulticlassPrediction; backs AUC, Brier and
// per-class percentages.
private final MulticlassReceiverOperatingCharacteristic ROC;
// Confusion matrix: receives only the prediction's hard label.
private final MulticlassConfusionMatrix conf;
// Categories iterated by getStatistics(); stored as passed (not copied).
private final String[] categories;

/**
 * Creates an empty evaluation over the given categories.
 *
 * @param categories
 *            the class labels this evaluation tracks
 */
public MulticlassModelEvaluation(String[] categories) {
    this.categories = categories;
    this.ROC = new MulticlassReceiverOperatingCharacteristic(categories);
    this.conf = new MulticlassConfusionMatrix(categories);
}
/**
 * Records one observation: the ROC tracker receives the full prediction,
 * the confusion matrix receives only its hard label.
 *
 * @param label
 *            the actual class
 * @param prediction
 *            the model's prediction
 */
public void add(String label, MulticlassPrediction prediction) {
    this.ROC.add(label, prediction);
    this.conf.addInfo(label, prediction.getLabel());
}

/**
 * Folds another evaluation's ROC and confusion-matrix state into this one.
 *
 * @param other
 *            must be a MulticlassModelEvaluation (a ClassCastException is
 *            thrown otherwise)
 */
public void merge(ModelEvaluation other) {
    MulticlassModelEvaluation that = (MulticlassModelEvaluation) other;
    this.ROC.add(that.ROC);
    this.conf.add(that.conf);
}
/**
 * Records one (actual, predicted) pair by unpacking it and delegating to the
 * two-argument add overload chosen by the pair's element types.
 * NOTE(review): the GenericPair type arguments appear to have been lost
 * here; presumably GenericPair<MulticlassLabel, MulticlassLabel> — confirm.
 *
 * @param labelPrediction
 *            pair whose first element is the actual label and whose second
 *            element is the prediction
 */
public void add(GenericPair labelPrediction) {
add(labelPrediction.first, labelPrediction.second);
}
/**
 * Records one observation from label objects. The prediction must be a
 * MulticlassPrediction, because the ROC tracker needs the full prediction.
 *
 * @param real
 *            the actual label
 * @param pred
 *            the predicted label; must be a MulticlassPrediction
 * @throws RuntimeException
 *             if {@code pred} is not a MulticlassPrediction
 */
public void add(MulticlassLabel real, MulticlassLabel pred) {
    if (pred instanceof MulticlassPrediction) {
        MulticlassPrediction prediction = (MulticlassPrediction) pred;
        add(real.getLabel(), prediction);
        return;
    }
    throw new java.lang.RuntimeException(
            "MulticlassModelEvaluation needs a MulticlassPrediction");
}
/** @return the multiclass AUC reported by the ROC tracker */
public double computeAUC() {
    return this.ROC.multiclassAUC();
}

/** @return the single-class AUC of {@code dim} reported by the ROC tracker */
public double computeAUC(String dim) {
    return this.ROC.singleClassAUC(dim);
}

/** @return the multiclass Brier score reported by the ROC tracker */
public double computeBrier() {
    return this.ROC.multiclassBrierScore();
}

/** @return the overall accuracy from the confusion matrix */
public double computeAccy() {
    return this.conf.computeAccuracy();
}

/** @return the one-vs-rest accuracy of {@code dim} from the confusion matrix */
public double computeAccy(String dim) {
    return this.conf.computeAccuracy(dim);
}

/** @return the support-weighted average F-measure from the confusion matrix */
public double computeFmeasure() {
    return this.conf.computeAverageFmeasure();
}

/** @return the F-measure of {@code dim} from the confusion matrix */
public double computeFmeasure(String dim) {
    return this.conf.computeFmeasure(dim);
}

/** @return the support-weighted average precision from the confusion matrix */
public double computePrecision() {
    return this.conf.computeAveragePrecision();
}

/** @return the precision of {@code dim} from the confusion matrix */
public double computePrecision(String dim) {
    return this.conf.computePrecision(dim);
}

/** @return the support-weighted average recall from the confusion matrix */
public double computeRecall() {
    return this.conf.computeAverageRecall();
}

/** @return the recall of {@code dim} from the confusion matrix */
public double computeRecall(String dim) {
    return this.conf.computeRecall(dim);
}

/** @return the percentage for {@code dim} reported by the ROC tracker */
public double computePercent(String dim) {
    return this.ROC.computePercent(dim);
}

/** @return the confusion matrix's debug output */
public String printDebug() {
    return this.conf.printDebug();
}
/**
 * Collects the averaged metrics and the per-category metrics into a map
 * sorted by metric name.
 *
 * @return metric name to metric value
 */
public Map<String, Double> getStatistics() {
    SortedMap<String, Double> stats = new TreeMap<String, Double>();
    stats.put("AUC (avg)", computeAUC());
    stats.put("Brier (avg)", computeBrier());
    stats.put("Acc (avg)", computeAccy());
    stats.put("F1 (avg)", computeFmeasure());
    stats.put("Prc (avg)", computePrecision());
    stats.put("Rec (avg)", computeRecall());
    for (int i = 0; i < categories.length; i++) {
        String category = categories[i];
        stats.put(category + ": Pct", computePercent(category));
        stats.put(category + ": AUC", computeAUC(category));
        stats.put(category + ": Acc", computeAccy(category));
        stats.put(category + ": F1", computeFmeasure(category));
        stats.put(category + ": Prc", computePrecision(category));
        stats.put(category + ": Rec", computeRecall(category));
    }
    return stats;
}
/**
 * Renders AUC, Brier score, the IR table and the raw confusion matrix.
 */
@Override
public String toString() {
    StringBuilder report = new StringBuilder();
    report.append("AUC: ").append(ROC.multiclassAUC()).append("\n");
    report.append("Brier: ").append(ROC.multiclassBrierScore()).append("\n");
    report.append("IR metrics:\n").append(conf.getIR()).append("\n");
    report.append("Confusion Matrix:\n").append(conf.toString()).append("\n");
    return report.toString();
}
public HashMap getObjects() {
HashMap m = new HashMap