Repository: datquocnguyen/LFTM
Branch: master
Commit: 2d074e52c6b9
Files: 23
Total size: 12.7 MB
Directory structure:
gitextract_qkmmkaa4/
├── License.txt
├── README.md
├── build.xml
├── jar/
│ └── LFTM.jar
├── lib/
│ ├── args4j-2.0.6.jar
│ └── mallet.jar
├── src/
│ ├── LFTM.java
│ ├── eval/
│ │ └── ClusteringEval.java
│ ├── models/
│ │ ├── LFDMM.java
│ │ ├── LFDMM_Inf.java
│ │ ├── LFLDA.java
│ │ ├── LFLDA_Inf.java
│ │ └── TopicVectorOptimizer.java
│ └── utility/
│ ├── CmdArgs.java
│ ├── FuncUtils.java
│ ├── LBFGS.java
│ ├── MTRandom.java
│ ├── MersenneTwister.java
│ └── Parallel.java
└── test/
├── corpus.LABEL
├── corpus.txt
├── corpus_test.txt
└── wordVectors.txt
================================================
FILE CONTENTS
================================================
================================================
FILE: License.txt
================================================
Implementations of the LF-LDA and LF-DMM latent feature topic models
Copyright (C) 2015-2016 by Dat Quoc Nguyen
dat.nguyen@students.mq.edu.au
Department of Computing, Macquarie University, Australia
This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version.
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
You should have received a copy of the GNU General Public License along with this program. If not, see <http://www.gnu.org/licenses/>
================================================
FILE: README.md
================================================
# LF-LDA and LF-DMM latent feature topic models
The implementations of the LF-LDA and LF-DMM latent feature topic models, as described in my TACL paper:
Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. [Improving Topic Models with Latent Feature Word Representations](https://tacl2013.cs.columbia.edu/ojs/index.php/tacl/article/view/582/158). <i>Transactions of the Association for Computational Linguistics</i>, vol. 3, pp. 299-313, 2015. [[.bib]](http://web.science.mq.edu.au/~dqnguyen/papers/TACL.bib) [[Datasets]](http://web.science.mq.edu.au/~dqnguyen/papers/TACL-datasets.zip) [[Example_20Newsgroups_20Topics_Top50Words]](http://web.science.mq.edu.au/~dqnguyen/papers/TACL_TopWords_N20_20Topics.zip)
The implementations of the LDA and DMM topic models are available at [http://jldadmm.sourceforge.net/](http://jldadmm.sourceforge.net/)
## Usage
This section describes how to use the implementations from the command line or terminal, using the pre-compiled `LFTM.jar` file.
It is expected that Java 1.7+ is already set up to run from the command line or terminal (for example, by adding Java to the `path` environment variable on Windows).
The pre-compiled `LFTM.jar` file and the source code are in the `jar` and `src` folders, respectively. Users can recompile the source code by simply running `ant` (it is expected that `ant` is already installed). In addition, users can find input examples in the `test` folder.
#### File format of input topic-modeling corpus
Similar to the `corpus.txt` file in the `test` folder, each line in the input topic-modeling corpus represents a document. Here, a document is a sequence of words/tokens separated by whitespace characters. Users should preprocess the input topic-modeling corpus before training the topic models, for example: down-casing, removing non-alphabetic characters and stop-words, and removing words shorter than 3 characters or appearing fewer than a certain number of times. A minimal preprocessing sketch is given below.
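For illustration, here is a minimal preprocessing sketch (a hypothetical helper, not part of LFTM; adjust the filters to your corpus):

```java
// Hypothetical preprocessing helper, NOT part of LFTM: down-cases a raw
// document, strips non-alphabetic characters and drops tokens shorter than 3
// characters. Stop-word and minimum-frequency filtering would come on top.
public class PreprocessSketch {
    public static String preprocess(String rawDoc) {
        StringBuilder sb = new StringBuilder();
        for (String token : rawDoc.toLowerCase().split("\\s+")) {
            String word = token.replaceAll("[^a-z]", "");
            if (word.length() >= 3) {
                if (sb.length() > 0)
                    sb.append(' ');
                sb.append(word);
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // "The 2 cats sat on the mat!" -> "the cats sat the mat"
        System.out.println(preprocess("The 2 cats sat on the mat!"));
    }
}
```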
#### Format of input word-vector file
Similar to the `wordVectors.txt` file in the `test` folder, each line in the input word-vector file starts with a word type, followed by its vector representation.
To obtain vector representations of words, users can either use pre-trained word vectors learned from large external corpora OR word vectors trained on the input topic-modeling corpus itself.
When using pre-trained word vectors learned from large external corpora, users have to remove from the input topic-modeling corpus every word that is not found in the input word-vector file; see the sketch below the following links.
Some sets of the pre-trained word vectors can be found at:
[Word2Vec: https://code.google.com/p/word2vec/](https://code.google.com/p/word2vec/)
[Glove: http://nlp.stanford.edu/projects/glove/](http://nlp.stanford.edu/projects/glove/)
If the input topic-modeling corpus is too domain-specific, the domain of the external corpus (from which the word vectors are derived) should not be too different from that of the input topic-modeling corpus. For example, when working in the biomedical domain, users may use Word2Vec or Glove to learn 50- or 100-dimensional word vectors on the large external MEDLINE corpus instead of using the pre-trained Word2Vec or Glove word vectors.
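As a rough sketch of the filtering mentioned above (the input paths are the `test` examples; the output file name `corpus.filtered.txt` is hypothetical):

```java
import java.io.*;
import java.util.*;

// Hypothetical helper, NOT part of LFTM: keeps only those corpus words that
// have a vector in the word-vector file, as required when using pre-trained
// word vectors.
public class FilterCorpusByVectors {
    public static void main(String[] args) throws IOException {
        // Collect the word types that have vectors (first token of each line)
        Set<String> vocab = new HashSet<String>();
        BufferedReader vr = new BufferedReader(new FileReader("test/wordVectors.txt"));
        for (String line; (line = vr.readLine()) != null;) {
            String[] parts = line.trim().split("\\s+");
            if (parts.length > 1)
                vocab.add(parts[0]);
        }
        vr.close();
        // Rewrite the corpus, dropping words without a vector
        BufferedReader cr = new BufferedReader(new FileReader("test/corpus.txt"));
        BufferedWriter cw = new BufferedWriter(new FileWriter("test/corpus.filtered.txt"));
        for (String doc; (doc = cr.readLine()) != null;) {
            StringBuilder sb = new StringBuilder();
            for (String word : doc.trim().split("\\s+")) {
                if (vocab.contains(word)) {
                    if (sb.length() > 0)
                        sb.append(' ');
                    sb.append(word);
                }
            }
            cw.write(sb.toString());
            cw.newLine();
        }
        cr.close();
        cw.close();
    }
}
```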
### Training LF-LDA and LF-DMM
`$ java [-Xmx2G] -jar jar/LFTM.jar -model <LFLDA_or_LFDMM> -corpus <Input_corpus_file_path> -vectors <Input_vector_file_path> [-ntopics <int>] [-alpha <double>] [-beta <double>] [-lambda <double>] [-initers <int>] [-niters <int>] [-twords <int>] [-name <String>] [-sstep <int>]`
where the options in [ ] are optional.
* `-model`: Specify the topic model.
* `-corpus`: Specify the path to the input training corpus file.
* `-vectors`: Specify the path to the file containing word vectors.
* `-ntopics <int>`: Specify the number of topics. The default value is 20.
* `-alpha <double>`: Specify the hyper-parameter alpha. Following [1, 2], the default value is 0.1.
* `-beta <double>`: Specify the hyper-parameter beta. The default value is 0.01. Following [2], you might also want to try beta value of 0.1 for short texts.
* `-lambda <double>`: Specify the mixture weight lambda (0.0 < lambda <= 1.0). Set the mixture weight lambda to 1.0 to obtain the best topic coherence.
In case of document clustering/classification evaluation, fine-tune this parameter to obtain the highest results if you have time; otherwise, try both values 0.6 and 1.0 (I would suggest setting lambda to 0.6 for normal text corpora and 1.0 for short-text corpora if you don't have time to try both).
* `-initers <int>`: Specify the number of initial sampling iterations to separate the counts for the latent feature component and the Dirichlet multinomial component. The default value is 2000.
* `-niters <int>`: Specify the number of sampling iterations for the latent feature topic models. The default value is 200.
* `-twords <int>`: Specify the number of the most probable topical words. The default value is 20.
* `-name <String>`: Specify a name for the topic modeling experiment. The default value is "model".
* `-sstep <int>`: Specify a step to save the sampling output (`-sstep` value < `-niters` value). The default value is 0 (i.e. only saving the output from the last sample).
NOTE that topic vectors are learned in parallel, so running the LFTM code on a multi-core machine gives a significantly faster training process; for example, use a multi-core computer, or set the number of CPUs requested for a remote job equal to the number of topics.
<b>Examples:</b>
`$ java -jar jar/LFTM.jar -model LFLDA -corpus test/corpus.txt -vectors test/wordVectors.txt -ntopics 4 -alpha 0.1 -beta 0.01 -lambda 0.6 -initers 500 -niters 50 -name testLFLDA`
Basically, with this command we run 500 `LDA` sampling iterations (i.e., `-initers 500`) for initialization and then run 50 `LF-LDA` sampling iterations (i.e., `-niters 50`). The output files are saved in the same folder as the input training corpus file, in this case the `test` folder: `testLFLDA.theta`, `testLFLDA.phi`, `testLFLDA.topWords`, `testLFLDA.topicAssignments` and `testLFLDA.paras`, containing the document-to-topic distributions, topic-to-word distributions, top topical words, topic assignments and model hyper-parameters, respectively. Similarly, we perform:
`$ java -jar jar/LFTM.jar -model LFDMM -corpus test/corpus.txt -vectors test/wordVectors.txt -ntopics 4 -alpha 0.1 -beta 0.1 -lambda 1.0 -initers 500 -niters 50 -name testLFDMM`
This produces the output files `testLFDMM.theta`, `testLFDMM.phi`, `testLFDMM.topWords`, `testLFDMM.topicAssignments` and `testLFDMM.paras`.
In the LF-LDA and LF-DMM latent feature topic models, a word is generated either by the latent feature topic-to-word component OR by the topic-to-word Dirichlet multinomial component. In the implementation, instead of using a binary selection variable to record this, I simply add the number of topics to the actual topic index. For example, with 20 topics, the output topic assignment `3 23 4 4 24 3 23 3 23 3 23` for a document means that the first word in the document is generated from topic 3 by the latent feature topic-to-word component. The second word is also generated from topic `23 - 20 = 3`, but by the topic-to-word Dirichlet multinomial component. The same reading applies to the remaining words in the document.
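A minimal sketch of decoding one assignment value under this encoding (the class name is illustrative):

```java
// Decoding one value from a .topicAssignments line, given K training topics.
public class DecodeAssignment {
    public static void main(String[] args) {
        int K = 20;                   // number of topics used in training
        int value = 23;               // an assignment value read from the file
        int topic = value % K;        // -> 3
        boolean fromDMM = value >= K; // -> true: Dirichlet multinomial component
        System.out.println("topic " + topic + ", fromDMM = " + fromDMM);
    }
}
```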
### Document clustering evaluation
Here, we treat each topic as a cluster and assign every document to the topic with the highest probability given the document. To get the clustering scores of Purity and normalized mutual information (NMI), we perform:
`$ java -jar jar/LFTM.jar -model Eval -label <Golden_label_file_path> -dir <Directory_path> -prob <Document-topic-prob/Suffix>`
* `-label`: Specify the path to the ground truth label file. Each line in this label file contains the golden label of the corresponding document in the input training corpus. See the `corpus.LABEL` and `corpus.txt` files in the `test` folder.
* `-dir`: Specify the path to the directory containing document-to-topic distribution files.
* `-prob`: Specify a document-to-topic distribution file or a group of document-to-topic distribution files in the specified directory.
<b>Examples:</b>
The command `$ java -jar jar/LFTM.jar -model Eval -label test/corpus.LABEL -dir test -prob testLFLDA.theta` will produce the clustering scores for the `testLFLDA.theta` file.
The command `$ java -jar jar/LFTM.jar -model Eval -label test/corpus.LABEL -dir test -prob testLFDMM.theta` will produce the clustering scores for the `testLFDMM.theta` file.
The command `$ java -jar jar/LFTM.jar -model Eval -label test/corpus.LABEL -dir test -prob theta` will produce the clustering scores for all the document-to-topic distribution files whose names end with `theta`. In this case, the distribution files are `testLFLDA.theta` and `testLFDMM.theta`. It also reports the mean and standard deviation of the clustering scores.
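For reference, the two scores follow Section 16.3 of Manning et al. (2008): with output clusters Ω, gold classes C and N documents,

```latex
\mathrm{Purity}(\Omega,C) = \frac{1}{N}\sum_{k}\max_{j}\,\lvert\omega_k \cap c_j\rvert,
\qquad
\mathrm{NMI}(\Omega,C) = \frac{2\,I(\Omega;C)}{H(\Omega)+H(C)}
```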
### Inference of topic distribution on unseen corpus
To infer topics on an unseen/new corpus using a pre-trained LF-LDA/LF-DMM topic model, we perform:
`$ java -jar jar/LFTM.jar -model <LFLDAinf_or_LFDMMinf> -paras <Hyperparameter_file_path> -corpus <Unseen_corpus_file_path> [-initers <int>] [-niters <int>] [-twords <int>] [-name <String>] [-sstep <int>]`
* `-paras`: Specify the path to the hyper-parameter file produced by the pre-trained LF-LDA/LF-DMM topic model.
<b>Examples:</b>
`$ java -jar jar/LFTM.jar -model LFLDAinf -paras test/testLFLDA.paras -corpus test/corpus_test.txt -initers 500 -niters 50 -name testLFLDAinf`
`$ java -jar jar/LFTM.jar -model LFDMMinf -paras test/testLFDMM.paras -corpus test/corpus_test.txt -initers 500 -niters 50 -name testLFDMMinf`
## Acknowledgments
The LF-LDA and LF-DMM implementations use utilities including the LBFGS implementation from the [MALLET toolkit](http://mallet.cs.umass.edu/), the random number generator in the [Java version of MersenneTwister](http://cs.gmu.edu/~sean/research/), the `Parallel.java` utility from the [Mines Java Toolkit](http://dhale.github.io/jtk/api/edu/mines/jtk/util/Parallel.html) and the [Java command line arguments parser](http://args4j.kohsuke.org/sample.html). I would like to thank the authors of these utilities for sharing their code.
## References
[1] Yue Lu, Qiaozhu Mei, and ChengXiang Zhai. 2011. Investigating task performance of probabilistic topic models: an empirical study of PLSA and LDA. Information Retrieval, 14:178–203.
[2] Jianhua Yin and Jianyong Wang. 2014. A Dirichlet Multinomial Mixture Model-based Approach for Short Text Clustering. In Proceedings of the 20th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 233–242.
================================================
FILE: build.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project name="LFTM" basedir="." default="main">
<property name="src.dir" value="src"/>
<property name="lib.dir" location="lib" />
<property name="classes.dir" value="bin"/>
<property name="jar.dir" value="jar"/>
<property name="main-class" value="LFTM"/>
<path id="build.classpath">
<fileset dir="${lib.dir}">
<include name="**/*.jar" />
</fileset>
</path>
<target name="clean">
<delete dir="${classes.dir}"/>
<delete dir="${jar.dir}"/>
</target>
<presetdef name="javac">
<javac includeantruntime="false" />
</presetdef>
<target name="compile">
<mkdir dir="${classes.dir}"/>
<javac srcdir="${src.dir}" destdir="${classes.dir}" classpathref="build.classpath"/>
</target>
<target name="jar" depends="compile">
<mkdir dir="${jar.dir}"/>
<jar destfile="${jar.dir}/${ant.project.name}.jar" basedir="${classes.dir}">
<manifest>
<attribute name="Class-Path" value="${build.classpath}"/>
<attribute name="Main-Class" value="${main-class}"/>
</manifest>
<zipgroupfileset dir="${lib.dir}"/>
</jar>
</target>
<target name="run" depends="jar">
<java fork="true" classname="${main-class}">
<classpath>
<path refid="build.classpath"/>
<pathelement location="${jar.dir}/${ant.project.name}.jar"/>
</classpath>
</java>
</target>
<target name="clean-build" depends="clean,jar"/>
<target name="main" depends="clean,compile,jar"/>
</project>
================================================
FILE: src/LFTM.java
================================================
import models.LFDMM;
import models.LFDMM_Inf;
import models.LFLDA;
import models.LFLDA_Inf;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import utility.CmdArgs;
import eval.ClusteringEval;
/**
* Implementations of the LF-LDA and LF-DMM latent feature topic models, using
* collapsed Gibbs sampling, as described in:
*
* Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. 2015.
* Improving Topic Models with Latent Feature Word Representations. Transactions
* of the Association for Computational Linguistics, vol. 3, pp. 299-313.
*
* @author Dat Quoc Nguyen
*
*/
public class LFTM
{
public static void main(String[] args)
{
CmdArgs cmdArgs = new CmdArgs();
CmdLineParser parser = new CmdLineParser(cmdArgs);
try {
parser.parseArgument(args);
if (cmdArgs.model.equals("LFLDA")) {
LFLDA lflda = new LFLDA(cmdArgs.corpus, cmdArgs.vectors,
cmdArgs.ntopics, cmdArgs.alpha, cmdArgs.beta,
cmdArgs.lambda, cmdArgs.initers, cmdArgs.niters,
cmdArgs.twords, cmdArgs.expModelName,
cmdArgs.initTopicAssgns, cmdArgs.savestep);
lflda.inference();
}
else if (cmdArgs.model.equals("LFDMM")) {
LFDMM lfdmm = new LFDMM(cmdArgs.corpus, cmdArgs.vectors,
cmdArgs.ntopics, cmdArgs.alpha, cmdArgs.beta,
cmdArgs.lambda, cmdArgs.initers, cmdArgs.niters,
cmdArgs.twords, cmdArgs.expModelName,
cmdArgs.initTopicAssgns, cmdArgs.savestep);
lfdmm.inference();
}
else if (cmdArgs.model.equals("LFLDAinf")) {
LFLDA_Inf lfldaInf = new LFLDA_Inf(cmdArgs.paras,
cmdArgs.corpus, cmdArgs.initers, cmdArgs.niters,
cmdArgs.twords, cmdArgs.expModelName, cmdArgs.savestep);
lfldaInf.inference();
}
else if (cmdArgs.model.equals("LFDMMinf")) {
LFDMM_Inf lfdmmInf = new LFDMM_Inf(cmdArgs.paras,
cmdArgs.corpus, cmdArgs.initers, cmdArgs.niters,
cmdArgs.twords, cmdArgs.expModelName, cmdArgs.savestep);
lfdmmInf.inference();
}
else if (cmdArgs.model.equals("Eval")) {
ClusteringEval.evaluate(cmdArgs.labelFile, cmdArgs.dir,
cmdArgs.prob);
}
else {
System.out
.println("Error: Option \"-model\" must get \"LFLDA\" or \"LFDMM\" or \"LFLDAinf\" or \"LFDMMinf\" or \"Eval\"");
System.out.println("\tLFLDA: Specify the LF-LDA topic model");
System.out.println("\tLFDMM: Specify the LF-DMM topic model");
System.out
.println("\tLFLDAinf: Infer topics for unseen corpus using a pre-trained LF-LDA model");
System.out
.println("\tLFDMMinf: Infer topics for unseen corpus using a pre-trained LF-DMM model");
System.out
.println("\tEval: Specify the document clustering evaluation");
help(parser);
return;
}
}
catch (CmdLineException cle) {
System.out.println("Error: " + cle.getMessage());
help(parser);
return;
}
catch (Exception e) {
System.out.println("Error: " + e.getMessage());
e.printStackTrace();
return;
}
}
public static void help(CmdLineParser parser)
{
System.out.println("java -jar LFTM.jar [options ...] [arguments...]");
parser.printUsage(System.out);
}
}
================================================
FILE: src/eval/ClusteringEval.java
================================================
package eval;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import utility.FuncUtils;
/**
* Implementation of the Purity and NMI clustering evaluation scores, as described in Section 16.3
* in:
*
 * Christopher D. Manning, Prabhakar Raghavan, and Hinrich Schütze. 2008. Introduction to
* Information Retrieval. Cambridge University Press.
*
* @author: Dat Quoc Nguyen
*/
public class ClusteringEval
{
String pathDocTopicProsFile;
String pathGoldenLabelsFile;
HashMap<String, Set<Integer>> goldenClusers;
HashMap<String, Set<Integer>> outputClusers;
int numDocs;
public ClusteringEval(String inPathGoldenLabelsFile, String inPathDocTopicProsFile)
throws Exception
{
pathDocTopicProsFile = inPathDocTopicProsFile;
pathGoldenLabelsFile = inPathGoldenLabelsFile;
goldenClusers = new HashMap<String, Set<Integer>>();
outputClusers = new HashMap<String, Set<Integer>>();
readGoldenLabelsFile();
readDocTopicProsFile();
}
public void readGoldenLabelsFile()
throws Exception
{
System.out.println("Reading golden labels file " + pathGoldenLabelsFile);
int id = 0;
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(pathGoldenLabelsFile));
for (String label; (label = br.readLine()) != null;) {
label = label.trim();
Set<Integer> ids = new HashSet<Integer>();
if (goldenClusers.containsKey(label))
ids = goldenClusers.get(label);
ids.add(id);
goldenClusers.put(label, ids);
id += 1;
}
}
catch (Exception e) {
e.printStackTrace();
}
numDocs = id;
}
public void readDocTopicProsFile()
throws Exception
{
System.out.println("Reading document-to-topic distribution file " + pathDocTopicProsFile);
HashMap<Integer, String> docLabelOutput = new HashMap<Integer, String>();
int docIndex = 0;
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(pathDocTopicProsFile));
for (String docTopicProbs; (docTopicProbs = br.readLine()) != null;) {
String[] pros = docTopicProbs.trim().split("\\s+");
double maxPro = 0.0;
int index = -1;
for (int topicIndex = 0; topicIndex < pros.length; topicIndex++) {
double pro = new Double(pros[topicIndex]);
if (pro > maxPro) {
maxPro = pro;
index = topicIndex;
}
}
docLabelOutput.put(docIndex, "Topic_" + new Integer(index).toString());
docIndex++;
}
}
catch (Exception e) {
e.printStackTrace();
}
if (numDocs != docIndex) {
System.out
.println("Error: the number of documents differs from the number of labels!");
throw new Exception();
}
for (Integer id : docLabelOutput.keySet()) {
String label = docLabelOutput.get(id);
Set<Integer> ids = new HashSet<Integer>();
if (outputClusers.containsKey(label))
ids = outputClusers.get(label);
ids.add(id);
outputClusers.put(label, ids);
}
}
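// Purity: match each output cluster to the gold cluster it overlaps most,
// then divide the total size of these overlaps by the number of documents
// (Manning et al. 2008, Section 16.3).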
public double computePurity()
{
int count = 0;
for (String label : outputClusers.keySet()) {
Set<Integer> docs = outputClusers.get(label);
int correctAssignedDocNum = 0;
for (String goldenLabel : goldenClusers.keySet()) {
Set<Integer> goldenDocs = goldenClusers.get(goldenLabel);
Set<Integer> outputDocs = new HashSet<Integer>(docs);
outputDocs.retainAll(goldenDocs);
if (outputDocs.size() >= correctAssignedDocNum)
correctAssignedDocNum = outputDocs.size();
}
count += correctAssignedDocNum;
}
double value = count * 1.0 / numDocs;
System.out.println("\tPurity accuracy: " + value);
return value;
}
public double computeNMIscore()
{
double MIscore = 0.0;
for (String label : outputClusers.keySet()) {
Set<Integer> docs = outputClusers.get(label);
for (String goldenLabel : goldenClusers.keySet()) {
Set<Integer> goldenDocs = goldenClusers.get(goldenLabel);
Set<Integer> outputDocs = new HashSet<Integer>(docs);
outputDocs.retainAll(goldenDocs);
double numCorrectAssignedDocs = outputDocs.size() * 1.0;
if (numCorrectAssignedDocs == 0.0)
continue;
MIscore += (numCorrectAssignedDocs / numDocs)
* Math.log(numCorrectAssignedDocs * numDocs
/ (docs.size() * goldenDocs.size()));
}
}
double entropy = 0.0;
for (String label : outputClusers.keySet()) {
Set<Integer> docs = outputClusers.get(label);
entropy += (-1.0 * docs.size() / numDocs) * Math.log(1.0 * docs.size() / numDocs);
}
for (String label : goldenClusers.keySet()) {
Set<Integer> docs = goldenClusers.get(label);
entropy += (-1.0 * docs.size() / numDocs) * Math.log(1.0 * docs.size() / numDocs);
}
double value = 2 * MIscore / entropy;
System.out.println("\tNMI score: " + value);
return value;
}
public static void evaluate(String pathGoldenLabelsFile,
String pathToFolderOfDocTopicProsFiles, String suffix)
throws Exception
{
BufferedWriter writer = new BufferedWriter(new FileWriter(pathToFolderOfDocTopicProsFiles
+ "/" + suffix + ".PurityNMI"));
writer.write("Golden-labels in: " + pathGoldenLabelsFile + "\n\n");
File[] files = new File(pathToFolderOfDocTopicProsFiles).listFiles();
List<Double> purity = new ArrayList<Double>(), nmi = new ArrayList<Double>();
for (File file : files) {
if (!file.getName().endsWith(suffix))
continue;
writer.write("Results for: " + file.getAbsolutePath() + "\n");
ClusteringEval dce = new ClusteringEval(pathGoldenLabelsFile, file.getAbsolutePath());
double value = dce.computePurity();
writer.write("\tPurity: " + value + "\n");
purity.add(value);
value = dce.computeNMIscore();
writer.write("\tNMI: " + value + "\n");
nmi.add(value);
}
if (purity.size() == 0 || nmi.size() == 0) {
System.out.println("Error: There is no file ending with " + suffix);
throw new Exception();
}
double[] purityValues = new double[purity.size()];
double[] nmiValues = new double[nmi.size()];
for (int i = 0; i < purity.size(); i++)
purityValues[i] = purity.get(i).doubleValue();
for (int i = 0; i < nmi.size(); i++)
nmiValues[i] = nmi.get(i).doubleValue();
writer.write("\n---\nMean purity: " + FuncUtils.mean(purityValues)
+ ", standard deviation: " + FuncUtils.stddev(purityValues));
writer.write("\nMean NMI: " + FuncUtils.mean(nmiValues) + ", standard deviation: "
+ FuncUtils.stddev(nmiValues));
System.out.println("---\nMean purity: " + FuncUtils.mean(purityValues)
+ ", standard deviation: " + FuncUtils.stddev(purityValues));
System.out.println("Mean NMI: " + FuncUtils.mean(nmiValues) + ", standard deviation: "
+ FuncUtils.stddev(nmiValues));
writer.close();
}
public static void main(String[] args)
throws Exception
{
ClusteringEval.evaluate("test/corpus.LABEL", "test", "theta");
}
}
================================================
FILE: src/models/LFDMM.java
================================================
package models;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import utility.FuncUtils;
import utility.LBFGS;
import utility.Parallel;
import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.Randoms;
/**
* Implementation of the LF-DMM latent feature topic model, using collapsed Gibbs sampling, as
* described in:
*
* Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. 2015. Improving Topic Models with
* Latent Feature Word Representations. Transactions of the Association for Computational
* Linguistics, vol. 3, pp. 299-313.
*
* @author Dat Quoc Nguyen
*/
public class LFDMM
{
public double alpha; // Hyper-parameter alpha
public double beta; // Hyper-parameter beta
// public double alphaSum; // alpha * numTopics
public double betaSum; // beta * vocabularySize
public int numTopics; // Number of topics
public int topWords; // Number of most probable words for each topic
public double lambda; // Mixture weight value
public int numInitIterations;
public int numIterations; // Number of EM-style sampling iterations
public List<List<Integer>> corpus; // Word ID-based corpus
public List<List<Integer>> topicAssignments; // Topics assignments for words
// in the corpus
public int numDocuments; // Number of documents in the corpus
public int numWordsInCorpus; // Number of words in the corpus
public HashMap<String, Integer> word2IdVocabulary; // Vocabulary to get ID
// given a word
public HashMap<Integer, String> id2WordVocabulary; // Vocabulary to get word
// given an ID
public int vocabularySize; // The number of word types in the corpus
// Number of documents assigned to a topic
public int[] docTopicCount;
// numTopics * vocabularySize matrix
// Given a topic: number of times a word type generated from the topic by
// the Dirichlet multinomial component
public int[][] topicWordCountDMM;
// Total number of words generated from each topic by the Dirichlet
// multinomial component
public int[] sumTopicWordCountDMM;
// numTopics * vocabularySize matrix
// Given a topic: number of times a word type generated from the topic by
// the latent feature component
public int[][] topicWordCountLF;
// Total number of words generated from each topic by the latent feature
// component
public int[] sumTopicWordCountLF;
// Double array used to sample a topic
public double[] multiPros;
// Path to the directory containing the corpus
public String folderPath;
// Path to the topic modeling corpus
public String corpusPath;
public String vectorFilePath;
public double[][] wordVectors; // Vector representations for words
public double[][] topicVectors;// Vector representations for topics
public int vectorSize; // Number of vector dimensions
public double[][] dotProductValues;
public double[][] expDotProductValues;
public double[] sumExpValues; // Partition function values
public final double l2Regularizer = 0.01; // L2 regularizer value for learning topic vectors
public final double tolerance = 0.05; // Tolerance value for LBFGS convergence
public String expName = "LFDMM";
public String orgExpName = "LFDMM";
public String tAssignsFilePath = "";
public int savestep = 0;
public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
int inNumIterations, int inTopWords)
throws Exception
{
this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
inNumInitIterations, inNumIterations, inTopWords, "LFDMM");
}
public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
int inNumIterations, int inTopWords, String inExpName)
throws Exception
{
this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
inNumInitIterations, inNumIterations, inTopWords, inExpName, "", 0);
}
public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
int inNumIterations, int inTopWords, String inExpName, String pathToTAfile)
throws Exception
{
this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
inNumInitIterations, inNumIterations, inTopWords, inExpName, pathToTAfile, 0);
}
public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
int inNumIterations, int inTopWords, String inExpName, int inSaveStep)
throws Exception
{
this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
inNumInitIterations, inNumIterations, inTopWords, inExpName, "", inSaveStep);
}
public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
int inNumIterations, int inTopWords, String inExpName, String pathToTAfile,
int inSaveStep)
throws Exception
{
alpha = inAlpha;
beta = inBeta;
lambda = inLambda;
numTopics = inNumTopics;
numIterations = inNumIterations;
numInitIterations = inNumInitIterations;
topWords = inTopWords;
savestep = inSaveStep;
expName = inExpName;
orgExpName = expName;
vectorFilePath = pathToWordVectorsFile;
corpusPath = pathToCorpus;
folderPath = pathToCorpus.substring(0,
Math.max(pathToCorpus.lastIndexOf("/"), pathToCorpus.lastIndexOf("\\")) + 1);
System.out.println("Reading topic modeling corpus: " + pathToCorpus);
word2IdVocabulary = new HashMap<String, Integer>();
id2WordVocabulary = new HashMap<Integer, String>();
corpus = new ArrayList<List<Integer>>();
numDocuments = 0;
numWordsInCorpus = 0;
BufferedReader br = null;
try {
int indexWord = -1;
br = new BufferedReader(new FileReader(pathToCorpus));
for (String doc; (doc = br.readLine()) != null;) {
if (doc.trim().length() == 0)
continue;
String[] words = doc.trim().split("\\s+");
List<Integer> document = new ArrayList<Integer>();
for (String word : words) {
if (word2IdVocabulary.containsKey(word)) {
document.add(word2IdVocabulary.get(word));
}
else {
indexWord += 1;
word2IdVocabulary.put(word, indexWord);
id2WordVocabulary.put(indexWord, word);
document.add(indexWord);
}
}
numDocuments++;
numWordsInCorpus += document.size();
corpus.add(document);
}
}
catch (Exception e) {
e.printStackTrace();
}
vocabularySize = word2IdVocabulary.size();
docTopicCount = new int[numTopics];
topicWordCountDMM = new int[numTopics][vocabularySize];
sumTopicWordCountDMM = new int[numTopics];
topicWordCountLF = new int[numTopics][vocabularySize];
sumTopicWordCountLF = new int[numTopics];
multiPros = new double[numTopics];
for (int i = 0; i < numTopics; i++) {
multiPros[i] = 1.0 / numTopics;
}
// alphaSum = numTopics * alpha;
betaSum = vocabularySize * beta;
readWordVectorsFile(vectorFilePath);
topicVectors = new double[numTopics][vectorSize];
dotProductValues = new double[numTopics][vocabularySize];
expDotProductValues = new double[numTopics][vocabularySize];
sumExpValues = new double[numTopics];
System.out
.println("Corpus size: " + numDocuments + " docs, " + numWordsInCorpus + " words");
System.out.println("Vocabuary size: " + vocabularySize);
System.out.println("Number of topics: " + numTopics);
System.out.println("alpha: " + alpha);
System.out.println("beta: " + beta);
System.out.println("lambda: " + lambda);
System.out.println("Number of initial sampling iterations: " + numInitIterations);
System.out.println("Number of EM-style sampling iterations for the LF-DMM model: "
+ numIterations);
System.out.println("Number of top topical words: " + topWords);
tAssignsFilePath = pathToTAfile;
if (tAssignsFilePath.length() > 0)
initialize(tAssignsFilePath);
else
initialize();
}
public void readWordVectorsFile(String pathToWordVectorsFile)
throws Exception
{
System.out.println("Reading word vectors from word-vectors file " + pathToWordVectorsFile
+ "...");
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(pathToWordVectorsFile));
String[] elements = br.readLine().trim().split("\\s+");
vectorSize = elements.length - 1;
wordVectors = new double[vocabularySize][vectorSize];
String word = elements[0];
if (word2IdVocabulary.containsKey(word)) {
for (int j = 0; j < vectorSize; j++) {
wordVectors[word2IdVocabulary.get(word)][j] = new Double(elements[j + 1]);
}
}
for (String line; (line = br.readLine()) != null;) {
elements = line.trim().split("\\s+");
word = elements[0];
if (word2IdVocabulary.containsKey(word)) {
for (int j = 0; j < vectorSize; j++) {
wordVectors[word2IdVocabulary.get(word)][j] = new Double(elements[j + 1]);
}
}
}
}
catch (Exception e) {
e.printStackTrace();
}
for (int i = 0; i < vocabularySize; i++) {
if (MatrixOps.absNorm(wordVectors[i]) == 0.0) {
System.out.println("The word \"" + id2WordVocabulary.get(i)
+ "\" doesn't have a corresponding vector!!!");
throw new Exception();
}
}
}
public void initialize()
throws IOException
{
System.out.println("Randomly initialzing topic assignments ...");
topicAssignments = new ArrayList<List<Integer>>();
for (int docId = 0; docId < numDocuments; docId++) {
List<Integer> topics = new ArrayList<Integer>();
int topic = FuncUtils.nextDiscrete(multiPros);
docTopicCount[topic] += 1;
int docSize = corpus.get(docId).size();
for (int j = 0; j < docSize; j++) {
int wordId = corpus.get(docId).get(j);
boolean component = new Randoms().nextBoolean();
int subtopic = topic;
if (!component) { // Generated from the latent feature component
topicWordCountLF[topic][wordId] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {// Generated from the Dirichlet multinomial component
topicWordCountDMM[topic][wordId] += 1;
sumTopicWordCountDMM[topic] += 1;
subtopic = subtopic + numTopics;
}
topics.add(subtopic);
}
topicAssignments.add(topics);
}
}
public void initialize(String pathToTopicAssignmentFile)
throws Exception
{
System.out.println("Reading topic-assignment file: " + pathToTopicAssignmentFile);
topicAssignments = new ArrayList<List<Integer>>();
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(pathToTopicAssignmentFile));
int docId = 0;
int numWords = 0;
for (String line; (line = br.readLine()) != null;) {
String[] strTopics = line.trim().split("\\s+");
List<Integer> topics = new ArrayList<Integer>();
int topic = new Integer(strTopics[0]) % numTopics;
docTopicCount[topic] += 1;
for (int j = 0; j < strTopics.length; j++) {
int wordId = corpus.get(docId).get(j);
int subtopic = new Integer(strTopics[j]);
if (subtopic == topic) {
topicWordCountLF[topic][wordId] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {
topicWordCountDMM[topic][wordId] += 1;
sumTopicWordCountDMM[topic] += 1;
}
topics.add(subtopic);
numWords++;
}
topicAssignments.add(topics);
docId++;
}
if ((docId != numDocuments) || (numWords != numWordsInCorpus)) {
System.out
.println("The topic modeling corpus and topic assignment file are not consistent!!!");
throw new Exception();
}
}
catch (Exception e) {
e.printStackTrace();
}
}
public void inference()
throws IOException
{
System.out.println("Running Gibbs sampling inference: ");
for (int iter = 1; iter <= numInitIterations; iter++) {
System.out.println("\tInitial sampling iteration: " + (iter));
sampleSingleInitialIteration();
}
for (int iter = 1; iter <= numIterations; iter++) {
System.out.println("\tLFDMM sampling iteration: " + (iter));
optimizeTopicVectors();
sampleSingleIteration();
if ((savestep > 0) && (iter % savestep == 0) && (iter < numIterations)) {
System.out.println("\t\tSaving the output from the " + iter + "^{th} sample");
expName = orgExpName + "-" + iter;
write();
}
}
expName = orgExpName;
writeParameters();
System.out.println("Writing output from the last sample ...");
write();
System.out.println("Sampling completed!");
}
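// Re-estimates the topic vectors in parallel (one task per topic) by running
// L-BFGS on the latent-feature word counts (at most 600 optimizer steps).
// If the optimizer rejects the problem, the L2 penalty is scaled up tenfold
// and the optimization is retried; an over/underflowed partition function is
// recomputed with the max-shift (log-sum-exp) trick.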
public void optimizeTopicVectors()
{
System.out.println("\t\tEstimating topic vectors ...");
sumExpValues = new double[numTopics];
dotProductValues = new double[numTopics][vocabularySize];
expDotProductValues = new double[numTopics][vocabularySize];
Parallel.loop(numTopics, new Parallel.LoopInt()
{
@Override
public void compute(int topic)
{
int rate = 1;
boolean check = true;
while (check) {
double l2Value = l2Regularizer * rate;
try {
TopicVectorOptimizer optimizer = new TopicVectorOptimizer(
topicVectors[topic], topicWordCountLF[topic], wordVectors, l2Value);
Optimizer gd = new LBFGS(optimizer, tolerance);
gd.optimize(600);
optimizer.getParameters(topicVectors[topic]);
sumExpValues[topic] = optimizer.computePartitionFunction(
dotProductValues[topic], expDotProductValues[topic]);
check = false;
if (sumExpValues[topic] == 0 || Double.isInfinite(sumExpValues[topic])) {
double max = -1000000000.0;
for (int index = 0; index < vocabularySize; index++) {
if (dotProductValues[topic][index] > max)
max = dotProductValues[topic][index];
}
for (int index = 0; index < vocabularySize; index++) {
expDotProductValues[topic][index] = Math
.exp(dotProductValues[topic][index] - max);
sumExpValues[topic] += expDotProductValues[topic][index];
}
}
}
catch (InvalidOptimizableException e) {
e.printStackTrace();
check = true;
}
rate = rate * 10;
}
}
});
}
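// One LF-DMM sampling iteration: for each document, remove its counts, then
// sample a single topic t with probability proportional to
//   (docTopicCount[t] + alpha) * prod_w (lambda * CatE(w|t) + (1 - lambda) * DMM(w|t)),
// where CatE(w|t) = expDotProductValues[t][w] / sumExpValues[t] and DMM(w|t)
// is the beta-smoothed multinomial; finally, hard-assign each word to
// whichever of the two components gives it the higher weighted probability.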
public void sampleSingleIteration()
{
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
List<Integer> document = corpus.get(dIndex);
int docSize = document.size();
int topic = topicAssignments.get(dIndex).get(0) % numTopics;
docTopicCount[topic] = docTopicCount[topic] - 1;
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = document.get(wIndex);// wordId
int subtopic = topicAssignments.get(dIndex).get(wIndex);
if (subtopic == topic) {
topicWordCountLF[topic][word] -= 1;
sumTopicWordCountLF[topic] -= 1;
}
else {
topicWordCountDMM[topic][word] -= 1;
sumTopicWordCountDMM[topic] -= 1;
}
}
// Sample a topic
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = document.get(wIndex);
multiPros[tIndex] *= (lambda * expDotProductValues[tIndex][word]
/ sumExpValues[tIndex] + (1 - lambda)
* (topicWordCountDMM[tIndex][word] + beta)
/ (sumTopicWordCountDMM[tIndex] + betaSum));
}
}
topic = FuncUtils.nextDiscrete(multiPros);
docTopicCount[topic] += 1;
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = document.get(wIndex);
int subtopic = topic;
if (lambda * expDotProductValues[topic][word] / sumExpValues[topic] > (1 - lambda)
* (topicWordCountDMM[topic][word] + beta)
/ (sumTopicWordCountDMM[topic] + betaSum)) {
topicWordCountLF[topic][word] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {
topicWordCountDMM[topic][word] += 1;
sumTopicWordCountDMM[topic] += 1;
subtopic += numTopics;
}
// Update topic assignments
topicAssignments.get(dIndex).set(wIndex, subtopic);
}
}
}
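// Initialization iteration: same as sampleSingleIteration, except that the
// latent feature component is approximated by a beta-smoothed multinomial
// over the LF counts, so no topic vectors are required yet.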
public void sampleSingleInitialIteration()
{
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
List<Integer> document = corpus.get(dIndex);
int docSize = document.size();
int topic = topicAssignments.get(dIndex).get(0) % numTopics;
docTopicCount[topic] = docTopicCount[topic] - 1;
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = document.get(wIndex);
int subtopic = topicAssignments.get(dIndex).get(wIndex);
if (topic == subtopic) {
topicWordCountLF[topic][word] -= 1;
sumTopicWordCountLF[topic] -= 1;
}
else {
topicWordCountDMM[topic][word] -= 1;
sumTopicWordCountDMM[topic] -= 1;
}
}
// Sample a topic
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = document.get(wIndex);
multiPros[tIndex] *= (lambda * (topicWordCountLF[tIndex][word] + beta)
/ (sumTopicWordCountLF[tIndex] + betaSum) + (1 - lambda)
* (topicWordCountDMM[tIndex][word] + beta)
/ (sumTopicWordCountDMM[tIndex] + betaSum));
}
}
topic = FuncUtils.nextDiscrete(multiPros);
docTopicCount[topic] += 1;
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = document.get(wIndex);// wordID
int subtopic = topic;
if (lambda * (topicWordCountLF[topic][word] + beta)
/ (sumTopicWordCountLF[topic] + betaSum) > (1 - lambda)
* (topicWordCountDMM[topic][word] + beta)
/ (sumTopicWordCountDMM[topic] + betaSum)) {
topicWordCountLF[topic][word] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {
topicWordCountDMM[topic][word] += 1;
sumTopicWordCountDMM[topic] += 1;
subtopic += numTopics;
}
// Update topic assignments
topicAssignments.get(dIndex).set(wIndex, subtopic);
}
}
}
public void writeParameters()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".paras"));
writer.write("-model" + "\t" + "LFDMM");
writer.write("\n-corpus" + "\t" + corpusPath);
writer.write("\n-vectors" + "\t" + vectorFilePath);
writer.write("\n-ntopics" + "\t" + numTopics);
writer.write("\n-alpha" + "\t" + alpha);
writer.write("\n-beta" + "\t" + beta);
writer.write("\n-lambda" + "\t" + lambda);
writer.write("\n-initers" + "\t" + numInitIterations);
writer.write("\n-niters" + "\t" + numIterations);
writer.write("\n-twords" + "\t" + topWords);
writer.write("\n-name" + "\t" + expName);
if (tAssignsFilePath.length() > 0)
writer.write("\n-initFile" + "\t" + tAssignsFilePath);
if (savestep > 0)
writer.write("\n-sstep" + "\t" + savestep);
writer.close();
}
public void writeDictionary()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
+ ".vocabulary"));
for (String word : word2IdVocabulary.keySet()) {
writer.write(word + " " + word2IdVocabulary.get(word) + "\n");
}
writer.close();
}
public void writeIDbasedCorpus()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
+ ".IDcorpus"));
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
int docSize = corpus.get(dIndex).size();
for (int wIndex = 0; wIndex < docSize; wIndex++) {
writer.write(corpus.get(dIndex).get(wIndex) + " ");
}
writer.write("\n");
}
writer.close();
}
public void writeTopicAssignments()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
+ ".topicAssignments"));
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
int docSize = corpus.get(dIndex).size();
for (int wIndex = 0; wIndex < docSize; wIndex++) {
writer.write(topicAssignments.get(dIndex).get(wIndex) + " ");
}
writer.write("\n");
}
writer.close();
}
public void writeTopicVectors()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
+ ".topicVectors"));
for (int i = 0; i < numTopics; i++) {
for (int j = 0; j < vectorSize; j++)
writer.write(topicVectors[i][j] + " ");
writer.write("\n");
}
writer.close();
}
public void writeTopTopicalWords()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
+ ".topWords"));
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
writer.write("Topic" + new Integer(tIndex) + ":");
Map<Integer, Double> topicWordProbs = new TreeMap<Integer, Double>();
for (int wIndex = 0; wIndex < vocabularySize; wIndex++) {
double pro = lambda * expDotProductValues[tIndex][wIndex] / sumExpValues[tIndex]
+ (1 - lambda) * (topicWordCountDMM[tIndex][wIndex] + beta)
/ (sumTopicWordCountDMM[tIndex] + betaSum);
topicWordProbs.put(wIndex, pro);
}
topicWordProbs = FuncUtils.sortByValueDescending(topicWordProbs);
Set<Integer> mostLikelyWords = topicWordProbs.keySet();
int count = 0;
for (Integer index : mostLikelyWords) {
if (count < topWords) {
writer.write(" " + id2WordVocabulary.get(index));
count += 1;
}
else {
writer.write("\n\n");
break;
}
}
}
writer.close();
}
public void writeTopicWordPros()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".phi"));
for (int t = 0; t < numTopics; t++) {
for (int w = 0; w < vocabularySize; w++) {
double pro = lambda * expDotProductValues[t][w] / sumExpValues[t] + (1 - lambda)
* (topicWordCountDMM[t][w] + beta) / (sumTopicWordCountDMM[t] + betaSum);
writer.write(pro + " ");
}
writer.write("\n");
}
writer.close();
}
public void writeDocTopicPros()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".theta"));
for (int i = 0; i < numDocuments; i++) {
int docSize = corpus.get(i).size();
double sum = 0.0;
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = corpus.get(i).get(wIndex);
multiPros[tIndex] *= (lambda * expDotProductValues[tIndex][word]
/ sumExpValues[tIndex] + (1 - lambda)
* (topicWordCountDMM[tIndex][word] + beta)
/ (sumTopicWordCountDMM[tIndex] + betaSum));
}
sum += multiPros[tIndex];
}
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
writer.write((multiPros[tIndex] / sum) + " ");
}
writer.write("\n");
}
writer.close();
}
public void write()
throws IOException
{
writeTopTopicalWords();
writeDocTopicPros();
writeTopicAssignments();
writeTopicWordPros();
}
public static void main(String args[])
throws Exception
{
LFDMM lfdmm = new LFDMM("test/corpus.txt", "test/wordVectors.txt", 4, 0.1, 0.01, 0.6, 2000,
200, 20, "testLFDMM");
lfdmm.writeParameters();
lfdmm.inference();
}
}
================================================
FILE: src/models/LFDMM_Inf.java
================================================
package models;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import utility.FuncUtils;
import utility.LBFGS;
import utility.Parallel;
import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.Randoms;
/**
* Implementation of the LF-DMM latent feature topic model, using collapsed
* Gibbs sampling, as described in:
*
* Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. 2015.
* Improving Topic Models with Latent Feature Word Representations. Transactions
* of the Association for Computational Linguistics, vol. 3, pp. 299-313.
*
* Inference of topic distribution on unseen corpus
*
* @author Dat Quoc Nguyen
*/
public class LFDMM_Inf
{
public double alpha; // Hyper-parameter alpha
public double beta; // Hyper-parameter beta
// public double alphaSum; // alpha * numTopics
public double betaSum; // beta * vocabularySize
public int numTopics; // Number of topics
public int topWords; // Number of most probable words for each topic
public double lambda; // Mixture weight value
public int numInitIterations;
public int numIterations; // Number of EM-style sampling iterations
public List<List<Integer>> corpus; // Word ID-based corpus
public List<List<Integer>> topicAssignments; // Topics assignments for words
// in the corpus
public int numDocuments; // Number of documents in the corpus
public int numWordsInCorpus; // Number of words in the corpus
public HashMap<String, Integer> word2IdVocabulary; // Vocabulary to get ID
// given a word
public HashMap<Integer, String> id2WordVocabulary; // Vocabulary to get word
// given an ID
public int vocabularySize; // The number of word types in the corpus
// Number of documents assigned to a topic
public int[] docTopicCount;
// numTopics * vocabularySize matrix
// Given a topic: number of times a word type generated from the topic by
// the Dirichlet multinomial component
public int[][] topicWordCountDMM;
// Total number of words generated from each topic by the Dirichlet
// multinomial component
public int[] sumTopicWordCountDMM;
// numTopics * vocabularySize matrix
// Given a topic: number of times a word type generated from the topic by
// the latent feature component
public int[][] topicWordCountLF;
// Total number of words generated from each topic by the latent feature
// component
public int[] sumTopicWordCountLF;
// Double array used to sample a topic
public double[] multiPros;
// Path to the directory containing the corpus
public String folderPath;
// Path to the topic modeling corpus
public String corpusPath;
public String vectorFilePath;
public double[][] wordVectors; // Vector representations for words
public double[][] topicVectors;// Vector representations for topics
public int vectorSize; // Number of vector dimensions
public double[][] dotProductValues;
public double[][] expDotProductValues;
public double[] sumExpValues; // Partition function values
public final double l2Regularizer = 0.01; // L2 regularizer value for
// learning topic vectors
public final double tolerance = 0.05; // Tolerance value for LBFGS
// convergence
public String expName = "LFDMMinf";
public String orgExpName = "LFDMMinf";
public int savestep = 0;
public LFDMM_Inf(String pathToTrainingParasFile, String pathToUnseenCorpus,
int inNumInitIterations, int inNumIterations, int inTopWords,
String inExpName, int inSaveStep)
throws Exception
{
HashMap<String, String> paras = parseTrainingParasFile(pathToTrainingParasFile);
if (!paras.get("-model").equals("LFDMM")) {
throw new Exception("Wrong pre-trained model!!!");
}
alpha = new Double(paras.get("-alpha"));
beta = new Double(paras.get("-beta"));
lambda = new Double(paras.get("-lambda"));
numTopics = new Integer(paras.get("-ntopics"));
numIterations = inNumIterations;
numInitIterations = inNumInitIterations;
topWords = inTopWords;
savestep = inSaveStep;
expName = inExpName;
orgExpName = expName;
vectorFilePath = paras.get("-vectors");
String trainingCorpus = paras.get("-corpus");
String trainingCorpusfolder = trainingCorpus.substring(
0,
Math.max(trainingCorpus.lastIndexOf("/"),
trainingCorpus.lastIndexOf("\\")) + 1);
String topicAssignment4TrainFile = trainingCorpusfolder
+ paras.get("-name") + ".topicAssignments";
word2IdVocabulary = new HashMap<String, Integer>();
id2WordVocabulary = new HashMap<Integer, String>();
initializeWordCount(trainingCorpus, topicAssignment4TrainFile);
corpusPath = pathToUnseenCorpus;
folderPath = pathToUnseenCorpus.substring(
0,
Math.max(pathToUnseenCorpus.lastIndexOf("/"),
pathToUnseenCorpus.lastIndexOf("\\")) + 1);
System.out.println("Reading unseen corpus: " + pathToUnseenCorpus);
corpus = new ArrayList<List<Integer>>();
numDocuments = 0;
numWordsInCorpus = 0;
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(pathToUnseenCorpus));
for (String doc; (doc = br.readLine()) != null;) {
if (doc.trim().length() == 0)
continue;
String[] words = doc.trim().split("\\s+");
List<Integer> document = new ArrayList<Integer>();
for (String word : words) {
if (word2IdVocabulary.containsKey(word)) {
document.add(word2IdVocabulary.get(word));
}
else {
// Skip this unknown word
}
}
numDocuments++;
numWordsInCorpus += document.size();
corpus.add(document);
}
}
catch (Exception e) {
e.printStackTrace();
}
docTopicCount = new int[numTopics];
multiPros = new double[numTopics];
for (int i = 0; i < numTopics; i++) {
multiPros[i] = 1.0 / numTopics;
}
// alphaSum = numTopics * alpha;
betaSum = vocabularySize * beta;
readWordVectorsFile(vectorFilePath);
topicVectors = new double[numTopics][vectorSize];
dotProductValues = new double[numTopics][vocabularySize];
expDotProductValues = new double[numTopics][vocabularySize];
sumExpValues = new double[numTopics];
System.out.println("Corpus size: " + numDocuments + " docs, "
+ numWordsInCorpus + " words");
System.out.println("Vocabuary size: " + vocabularySize);
System.out.println("Number of topics: " + numTopics);
System.out.println("alpha: " + alpha);
System.out.println("beta: " + beta);
System.out.println("lambda: " + lambda);
System.out.println("Number of initial sampling iterations: "
+ numInitIterations);
System.out
.println("Number of EM-style sampling iterations for the LF-DMM model: "
+ numIterations);
System.out.println("Number of top topical words: " + topWords);
initialize();
}
private HashMap<String, String> parseTrainingParasFile(
String pathToTrainingParasFile)
throws Exception
{
HashMap<String, String> paras = new HashMap<String, String>();
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(pathToTrainingParasFile));
for (String line; (line = br.readLine()) != null;) {
if (line.trim().length() == 0)
continue;
String[] paraOptions = line.trim().split("\\s+");
paras.put(paraOptions[0], paraOptions[1]);
}
}
catch (Exception e) {
e.printStackTrace();
}
return paras;
}
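// Rebuilds the vocabulary and the per-topic LF/DMM word counts by replaying
// the training corpus against its saved .topicAssignments file, so that
// inference on the unseen corpus starts from the pre-trained model's counts.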
private void initializeWordCount(String pathToTrainingCorpus,
String pathToTopicAssignmentFile)
{
System.out.println("Loading pre-trained model...");
List<List<Integer>> trainCorpus = new ArrayList<List<Integer>>();
BufferedReader br = null;
try {
int indexWord = -1;
br = new BufferedReader(new FileReader(pathToTrainingCorpus));
for (String doc; (doc = br.readLine()) != null;) {
if (doc.trim().length() == 0)
continue;
String[] words = doc.trim().split("\\s+");
List<Integer> document = new ArrayList<Integer>();
for (String word : words) {
if (word2IdVocabulary.containsKey(word)) {
document.add(word2IdVocabulary.get(word));
}
else {
indexWord += 1;
word2IdVocabulary.put(word, indexWord);
id2WordVocabulary.put(indexWord, word);
document.add(indexWord);
}
}
trainCorpus.add(document);
}
}
catch (Exception e) {
e.printStackTrace();
}
vocabularySize = word2IdVocabulary.size();
topicWordCountDMM = new int[numTopics][vocabularySize];
sumTopicWordCountDMM = new int[numTopics];
topicWordCountLF = new int[numTopics][vocabularySize];
sumTopicWordCountLF = new int[numTopics];
try {
br = new BufferedReader(new FileReader(pathToTopicAssignmentFile));
int docId = 0;
for (String line; (line = br.readLine()) != null;) {
String[] strTopics = line.trim().split("\\s+");
int topic = new Integer(strTopics[0]) % numTopics;
for (int j = 0; j < strTopics.length; j++) {
int wordId = trainCorpus.get(docId).get(j);
int subtopic = new Integer(strTopics[j]);
if (subtopic == topic) {
topicWordCountLF[topic][wordId] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {
topicWordCountDMM[topic][wordId] += 1;
sumTopicWordCountDMM[topic] += 1;
}
}
docId++;
}
}
catch (Exception e) {
e.printStackTrace();
}
}
public void readWordVectorsFile(String pathToWordVectorsFile)
throws Exception
{
System.out.println("Reading word vectors from word-vectors file "
+ pathToWordVectorsFile + "...");
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(pathToWordVectorsFile));
String[] elements = br.readLine().trim().split("\\s+");
vectorSize = elements.length - 1;
wordVectors = new double[vocabularySize][vectorSize];
String word = elements[0];
if (word2IdVocabulary.containsKey(word)) {
for (int j = 0; j < vectorSize; j++) {
wordVectors[word2IdVocabulary.get(word)][j] = new Double(
elements[j + 1]);
}
}
for (String line; (line = br.readLine()) != null;) {
elements = line.trim().split("\\s+");
word = elements[0];
if (word2IdVocabulary.containsKey(word)) {
for (int j = 0; j < vectorSize; j++) {
wordVectors[word2IdVocabulary.get(word)][j] = new Double(
elements[j + 1]);
}
}
}
}
catch (Exception e) {
e.printStackTrace();
}
for (int i = 0; i < vocabularySize; i++) {
if (MatrixOps.absNorm(wordVectors[i]) == 0.0) {
System.out.println("The word \"" + id2WordVocabulary.get(i)
+ "\" doesn't have a corresponding vector!!!");
throw new Exception();
}
}
}
public void initialize()
throws IOException
{
System.out.println("Randomly initialzing topic assignments ...");
topicAssignments = new ArrayList<List<Integer>>();
for (int docId = 0; docId < numDocuments; docId++) {
List<Integer> topics = new ArrayList<Integer>();
int topic = FuncUtils.nextDiscrete(multiPros);
docTopicCount[topic] += 1;
int docSize = corpus.get(docId).size();
for (int j = 0; j < docSize; j++) {
int wordId = corpus.get(docId).get(j);
boolean component = new Randoms().nextBoolean();
int subtopic = topic;
if (!component) { // Generated from the latent feature component
topicWordCountLF[topic][wordId] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {// Generated from the Dirichlet multinomial component
topicWordCountDMM[topic][wordId] += 1;
sumTopicWordCountDMM[topic] += 1;
subtopic = subtopic + numTopics;
}
topics.add(subtopic);
}
topicAssignments.add(topics);
}
}
public void inference()
throws IOException
{
System.out.println("Running Gibbs sampling inference: ");
for (int iter = 1; iter <= numInitIterations; iter++) {
System.out.println("\tInitial sampling iteration: " + (iter));
sampleSingleInitialIteration();
}
for (int iter = 1; iter <= numIterations; iter++) {
System.out.println("\tLFDMM sampling iteration: " + (iter));
optimizeTopicVectors();
sampleSingleIteration();
if ((savestep > 0) && (iter % savestep == 0)
&& (iter < numIterations)) {
System.out.println("\t\tSaving the output from the " + iter
+ "^{th} sample");
expName = orgExpName + "-" + iter;
write();
}
}
expName = orgExpName;
writeParameters();
System.out.println("Writing output from the last sample ...");
write();
System.out.println("Sampling completed!");
}
public void optimizeTopicVectors()
{
System.out.println("\t\tEstimating topic vectors ...");
sumExpValues = new double[numTopics];
dotProductValues = new double[numTopics][vocabularySize];
expDotProductValues = new double[numTopics][vocabularySize];
Parallel.loop(numTopics, new Parallel.LoopInt()
{
@Override
public void compute(int topic)
{
int rate = 1;
boolean check = true;
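// If L-BFGS throws an InvalidOptimizableException, retry with a 10-times
// larger L2 regularizer until the optimization succeeds.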
while (check) {
double l2Value = l2Regularizer * rate;
try {
TopicVectorOptimizer optimizer = new TopicVectorOptimizer(
topicVectors[topic], topicWordCountLF[topic],
wordVectors, l2Value);
Optimizer gd = new LBFGS(optimizer, tolerance);
gd.optimize(600);
optimizer.getParameters(topicVectors[topic]);
sumExpValues[topic] = optimizer
.computePartitionFunction(dotProductValues[topic],
expDotProductValues[topic]);
check = false;
if (sumExpValues[topic] == 0
|| Double.isInfinite(sumExpValues[topic])) {
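// Numerical overflow/underflow guard: recompute the partition function
// with the maximum dot product subtracted (log-sum-exp trick); this
// rescales expDotProductValues by a constant factor, which cancels out
// when the values are normalized by sumExpValues.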
double max = -1000000000.0;
for (int index = 0; index < vocabularySize; index++) {
if (dotProductValues[topic][index] > max)
max = dotProductValues[topic][index];
}
for (int index = 0; index < vocabularySize; index++) {
expDotProductValues[topic][index] = Math
.exp(dotProductValues[topic][index] - max);
sumExpValues[topic] += expDotProductValues[topic][index];
}
}
}
catch (InvalidOptimizableException e) {
e.printStackTrace();
check = true;
}
rate = rate * 10;
}
}
});
}
public void sampleSingleIteration()
{
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
List<Integer> document = corpus.get(dIndex);
int docSize = document.size();
int topic = topicAssignments.get(dIndex).get(0) % numTopics;
docTopicCount[topic] = docTopicCount[topic] - 1;
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = document.get(wIndex);// wordId
int subtopic = topicAssignments.get(dIndex).get(wIndex);
if (subtopic == topic) {
topicWordCountLF[topic][word] -= 1;
sumTopicWordCountLF[topic] -= 1;
}
else {
topicWordCountDMM[topic][word] -= 1;
sumTopicWordCountDMM[topic] -= 1;
}
}
// Sample a topic
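// One topic per document (DMM): the full conditional is
// P(z_d = t | rest) proportional to (docTopicCount[t] + alpha)
// * prod_{w in d} [ lambda * exp(tau_t . omega_w) / Z_t
// + (1 - lambda) * (topicWordCountDMM[t][w] + beta) / (sumTopicWordCountDMM[t] + betaSum) ]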
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = document.get(wIndex);
multiPros[tIndex] *= (lambda
* expDotProductValues[tIndex][word]
/ sumExpValues[tIndex] + (1 - lambda)
* (topicWordCountDMM[tIndex][word] + beta)
/ (sumTopicWordCountDMM[tIndex] + betaSum));
}
}
topic = FuncUtils.nextDiscrete(multiPros);
docTopicCount[topic] += 1;
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = document.get(wIndex);
int subtopic = topic;
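// Set the component indicator deterministically: assign the word to the
// latent feature component iff its weighted LF probability exceeds its
// weighted Dirichlet multinomial probability under the sampled topic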
if (lambda * expDotProductValues[topic][word]
/ sumExpValues[topic] > (1 - lambda)
* (topicWordCountDMM[topic][word] + beta)
/ (sumTopicWordCountDMM[topic] + betaSum)) {
topicWordCountLF[topic][word] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {
topicWordCountDMM[topic][word] += 1;
sumTopicWordCountDMM[topic] += 1;
subtopic += numTopics;
}
// Update topic assignments
topicAssignments.get(dIndex).set(wIndex, subtopic);
}
}
}
public void sampleSingleInitialIteration()
{
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
List<Integer> document = corpus.get(dIndex);
int docSize = document.size();
int topic = topicAssignments.get(dIndex).get(0) % numTopics;
docTopicCount[topic] = docTopicCount[topic] - 1;
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = document.get(wIndex);
int subtopic = topicAssignments.get(dIndex).get(wIndex);
if (topic == subtopic) {
topicWordCountLF[topic][word] -= 1;
sumTopicWordCountLF[topic] -= 1;
}
else {
topicWordCountDMM[topic][word] -= 1;
sumTopicWordCountDMM[topic] -= 1;
}
}
// Sample a topic
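// During initial iterations the topic vectors have not been estimated yet,
// so the latent feature component is scored by its smoothed word counts
// instead of the topic-word dot products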
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = document.get(wIndex);
multiPros[tIndex] *= (lambda
* (topicWordCountLF[tIndex][word] + beta)
/ (sumTopicWordCountLF[tIndex] + betaSum) + (1 - lambda)
* (topicWordCountDMM[tIndex][word] + beta)
/ (sumTopicWordCountDMM[tIndex] + betaSum));
}
}
topic = FuncUtils.nextDiscrete(multiPros);
docTopicCount[topic] += 1;
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = document.get(wIndex);// wordID
int subtopic = topic;
if (lambda * (topicWordCountLF[topic][word] + beta)
/ (sumTopicWordCountLF[topic] + betaSum) > (1 - lambda)
* (topicWordCountDMM[topic][word] + beta)
/ (sumTopicWordCountDMM[topic] + betaSum)) {
topicWordCountLF[topic][word] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {
topicWordCountDMM[topic][word] += 1;
sumTopicWordCountDMM[topic] += 1;
subtopic += numTopics;
}
// Update topic assignments
topicAssignments.get(dIndex).set(wIndex, subtopic);
}
}
}
public void writeParameters()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".paras"));
writer.write("-model" + "\t" + "LFDMM");
writer.write("\n-corpus" + "\t" + corpusPath);
writer.write("\n-vectors" + "\t" + vectorFilePath);
writer.write("\n-ntopics" + "\t" + numTopics);
writer.write("\n-alpha" + "\t" + alpha);
writer.write("\n-beta" + "\t" + beta);
writer.write("\n-lambda" + "\t" + lambda);
writer.write("\n-initers" + "\t" + numInitIterations);
writer.write("\n-niters" + "\t" + numIterations);
writer.write("\n-twords" + "\t" + topWords);
writer.write("\n-name" + "\t" + expName);
if (savestep > 0)
writer.write("\n-sstep" + "\t" + savestep);
writer.close();
}
public void writeDictionary()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".vocabulary"));
for (String word : word2IdVocabulary.keySet()) {
writer.write(word + " " + word2IdVocabulary.get(word) + "\n");
}
writer.close();
}
public void writeIDbasedCorpus()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".IDcorpus"));
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
int docSize = corpus.get(dIndex).size();
for (int wIndex = 0; wIndex < docSize; wIndex++) {
writer.write(corpus.get(dIndex).get(wIndex) + " ");
}
writer.write("\n");
}
writer.close();
}
public void writeTopicAssignments()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".topicAssignments"));
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
int docSize = corpus.get(dIndex).size();
for (int wIndex = 0; wIndex < docSize; wIndex++) {
writer.write(topicAssignments.get(dIndex).get(wIndex) + " ");
}
writer.write("\n");
}
writer.close();
}
public void writeTopicVectors()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".topicVectors"));
for (int i = 0; i < numTopics; i++) {
for (int j = 0; j < vectorSize; j++)
writer.write(topicVectors[i][j] + " ");
writer.write("\n");
}
writer.close();
}
public void writeTopTopicalWords()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".topWords"));
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
writer.write("Topic" + new Integer(tIndex) + ":");
Map<Integer, Double> topicWordProbs = new TreeMap<Integer, Double>();
for (int wIndex = 0; wIndex < vocabularySize; wIndex++) {
double pro = lambda * expDotProductValues[tIndex][wIndex]
/ sumExpValues[tIndex] + (1 - lambda)
* (topicWordCountDMM[tIndex][wIndex] + beta)
/ (sumTopicWordCountDMM[tIndex] + betaSum);
topicWordProbs.put(wIndex, pro);
}
topicWordProbs = FuncUtils.sortByValueDescending(topicWordProbs);
Set<Integer> mostLikelyWords = topicWordProbs.keySet();
int count = 0;
for (Integer index : mostLikelyWords) {
if (count < topWords) {
writer.write(" " + id2WordVocabulary.get(index));
count += 1;
}
else {
writer.write("\n\n");
break;
}
}
}
writer.close();
}
public void writeTopicWordPros()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".phi"));
for (int t = 0; t < numTopics; t++) {
for (int w = 0; w < vocabularySize; w++) {
double pro = lambda * expDotProductValues[t][w]
/ sumExpValues[t] + (1 - lambda)
* (topicWordCountDMM[t][w] + beta)
/ (sumTopicWordCountDMM[t] + betaSum);
writer.write(pro + " ");
}
writer.write("\n");
}
writer.close();
}
public void writeDocTopicPros()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".theta"));
for (int i = 0; i < numDocuments; i++) {
int docSize = corpus.get(i).size();
double sum = 0.0;
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = corpus.get(i).get(wIndex);
multiPros[tIndex] *= (lambda
* expDotProductValues[tIndex][word]
/ sumExpValues[tIndex] + (1 - lambda)
* (topicWordCountDMM[tIndex][word] + beta)
/ (sumTopicWordCountDMM[tIndex] + betaSum));
}
sum += multiPros[tIndex];
}
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
writer.write((multiPros[tIndex] / sum) + " ");
}
writer.write("\n");
}
writer.close();
}
public void write()
throws IOException
{
writeTopTopicalWords();
writeDocTopicPros();
writeTopicAssignments();
writeTopicWordPros();
}
public static void main(String args[])
throws Exception
{
LFDMM_Inf lfdmm = new LFDMM_Inf("test/testLFDMM.paras",
"test/corpus_test.txt", 2000, 200, 20, "testLFDMMinf", 0);
lfdmm.inference();
}
}
================================================
FILE: src/models/LFLDA.java
================================================
package models;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import utility.FuncUtils;
import utility.LBFGS;
import utility.Parallel;
import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.MatrixOps;
/**
* Implementation of the LF-LDA latent feature topic model, using collapsed Gibbs sampling, as
* described in:
*
* Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. 2015. Improving Topic Models with
* Latent Feature Word Representations. Transactions of the Association for Computational
* Linguistics, vol. 3, pp. 299-313.
*
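* A minimal usage sketch, mirroring the main method of this class and using
* the bundled test data (paths are relative to the repository root):
*
* <pre>
* LFLDA model = new LFLDA("test/corpus.txt", "test/wordVectors.txt", 4, 0.1,
*         0.01, 0.6, 2000, 200, 20, "testLFLDA");
* model.inference();
* </pre>
*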
* @author Dat Quoc Nguyen
*/
public class LFLDA
{
public double alpha; // Hyper-parameter alpha
public double beta; // Hyper-parameter beta
public double alphaSum; // alpha * numTopics
public double betaSum; // beta * vocabularySize
public int numTopics; // Number of topics
public int topWords; // Number of most probable words for each topic
public double lambda; // Mixture weight value
public int numInitIterations; // Number of initial sampling iterations
public int numIterations; // Number of EM-style sampling iterations
public List<List<Integer>> corpus; // Word ID-based corpus
public List<List<Integer>> topicAssignments; // Topic assignments for words
// in the corpus
public int numDocuments; // Number of documents in the corpus
public int numWordsInCorpus; // Number of words in the corpus
public HashMap<String, Integer> word2IdVocabulary; // Vocabulary to get ID
// given a word
public HashMap<Integer, String> id2WordVocabulary; // Vocabulary to get word
// given an ID
public int vocabularySize; // The number of word types in the corpus
// numDocuments * numTopics matrix
// Given a document: number of its words assigned to each topic
public int[][] docTopicCount;
// Number of words in every document
public int[] sumDocTopicCount;
// numTopics * vocabularySize matrix
// Given a topic: number of times a word type generated from the topic by
// the Dirichlet multinomial component
public int[][] topicWordCountLDA;
// Total number of words generated from each topic by the Dirichlet
// multinomial component
public int[] sumTopicWordCountLDA;
// numTopics * vocabularySize matrix
// Given a topic: number of times a word type generated from the topic by
// the latent feature component
public int[][] topicWordCountLF;
// Total number of words generated from each topic by the latent feature
// component
public int[] sumTopicWordCountLF;
// Double array used to sample a topic
public double[] multiPros;
// Path to the directory containing the corpus
public String folderPath;
// Path to the topic modeling corpus
public String corpusPath;
public String vectorFilePath;
public double[][] wordVectors; // Vector representations for words
public double[][] topicVectors;// Vector representations for topics
public int vectorSize; // Number of vector dimensions
public double[][] dotProductValues;
public double[][] expDotProductValues;
public double[] sumExpValues; // Partition function values
public final double l2Regularizer = 0.01; // L2 regularizer value for learning topic vectors
public final double tolerance = 0.05; // Tolerance value for LBFGS convergence
public String expName = "LFLDA";
public String orgExpName = "LFLDA";
public String tAssignsFilePath = "";
public int savestep = 0;
public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
int inNumIterations, int inTopWords)
throws Exception
{
this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
inNumInitIterations, inNumIterations, inTopWords, "LFLDA");
}
public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
int inNumIterations, int inTopWords, String inExpName)
throws Exception
{
this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
inNumInitIterations, inNumIterations, inTopWords, inExpName, "", 0);
}
public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
int inNumIterations, int inTopWords, String inExpName, String pathToTAfile)
throws Exception
{
this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
inNumInitIterations, inNumIterations, inTopWords, inExpName, pathToTAfile, 0);
}
public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
int inNumIterations, int inTopWords, String inExpName, int inSaveStep)
throws Exception
{
this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
inNumInitIterations, inNumIterations, inTopWords, inExpName, "", inSaveStep);
}
public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
int inNumIterations, int inTopWords, String inExpName, String pathToTAfile,
int inSaveStep)
throws Exception
{
alpha = inAlpha;
beta = inBeta;
lambda = inLambda;
numTopics = inNumTopics;
numIterations = inNumIterations;
numInitIterations = inNumInitIterations;
topWords = inTopWords;
savestep = inSaveStep;
expName = inExpName;
orgExpName = expName;
vectorFilePath = pathToWordVectorsFile;
corpusPath = pathToCorpus;
folderPath = pathToCorpus.substring(0,
Math.max(pathToCorpus.lastIndexOf("/"), pathToCorpus.lastIndexOf("\\")) + 1);
System.out.println("Reading topic modeling corpus: " + pathToCorpus);
word2IdVocabulary = new HashMap<String, Integer>();
id2WordVocabulary = new HashMap<Integer, String>();
corpus = new ArrayList<List<Integer>>();
numDocuments = 0;
numWordsInCorpus = 0;
BufferedReader br = null;
try {
int indexWord = -1;
br = new BufferedReader(new FileReader(pathToCorpus));
for (String doc; (doc = br.readLine()) != null;) {
if (doc.trim().length() == 0)
continue;
String[] words = doc.trim().split("\\s+");
List<Integer> document = new ArrayList<Integer>();
for (String word : words) {
if (word2IdVocabulary.containsKey(word)) {
document.add(word2IdVocabulary.get(word));
}
else {
indexWord += 1;
word2IdVocabulary.put(word, indexWord);
id2WordVocabulary.put(indexWord, word);
document.add(indexWord);
}
}
numDocuments++;
numWordsInCorpus += document.size();
corpus.add(document);
}
}
catch (Exception e) {
e.printStackTrace();
}
vocabularySize = word2IdVocabulary.size();
docTopicCount = new int[numDocuments][numTopics];
sumDocTopicCount = new int[numDocuments];
topicWordCountLDA = new int[numTopics][vocabularySize];
sumTopicWordCountLDA = new int[numTopics];
topicWordCountLF = new int[numTopics][vocabularySize];
sumTopicWordCountLF = new int[numTopics];
multiPros = new double[numTopics * 2];
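// Entries [0, numTopics) weight topic t paired with the latent feature
// component; entries [numTopics, 2 * numTopics) weight the same topics
// paired with the Dirichlet multinomial component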
for (int i = 0; i < numTopics * 2; i++) {
multiPros[i] = 1.0 / numTopics;
}
alphaSum = numTopics * alpha;
betaSum = vocabularySize * beta;
readWordVectorsFile(vectorFilePath);
topicVectors = new double[numTopics][vectorSize];
dotProductValues = new double[numTopics][vocabularySize];
expDotProductValues = new double[numTopics][vocabularySize];
sumExpValues = new double[numTopics];
System.out
.println("Corpus size: " + numDocuments + " docs, " + numWordsInCorpus + " words");
System.out.println("Vocabuary size: " + vocabularySize);
System.out.println("Number of topics: " + numTopics);
System.out.println("alpha: " + alpha);
System.out.println("beta: " + beta);
System.out.println("lambda: " + lambda);
System.out.println("Number of initial sampling iterations: " + numInitIterations);
System.out.println("Number of EM-style sampling iterations for the LF-LDA model: "
+ numIterations);
System.out.println("Number of top topical words: " + topWords);
tAssignsFilePath = pathToTAfile;
if (tAssignsFilePath.length() > 0)
initialize(tAssignsFilePath);
else
initialize();
}
public void readWordVectorsFile(String pathToWordVectorsFile)
throws Exception
{
System.out.println("Reading word vectors from word-vectors file " + pathToWordVectorsFile
+ "...");
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(pathToWordVectorsFile));
String[] elements = br.readLine().trim().split("\\s+");
vectorSize = elements.length - 1;
wordVectors = new double[vocabularySize][vectorSize];
String word = elements[0];
if (word2IdVocabulary.containsKey(word)) {
for (int j = 0; j < vectorSize; j++) {
wordVectors[word2IdVocabulary.get(word)][j] = new Double(elements[j + 1]);
}
}
for (String line; (line = br.readLine()) != null;) {
elements = line.trim().split("\\s+");
word = elements[0];
if (word2IdVocabulary.containsKey(word)) {
for (int j = 0; j < vectorSize; j++) {
wordVectors[word2IdVocabulary.get(word)][j] = new Double(elements[j + 1]);
}
}
}
}
catch (Exception e) {
e.printStackTrace();
}
for (int i = 0; i < vocabularySize; i++) {
if (MatrixOps.absNorm(wordVectors[i]) == 0.0) {
System.out.println("The word \"" + id2WordVocabulary.get(i)
+ "\" doesn't have a corresponding vector!!!");
throw new Exception();
}
}
}
public void initialize()
throws IOException
{
System.out.println("Randomly initialzing topic assignments ...");
topicAssignments = new ArrayList<List<Integer>>();
for (int docId = 0; docId < numDocuments; docId++) {
List<Integer> topics = new ArrayList<Integer>();
int docSize = corpus.get(docId).size();
for (int j = 0; j < docSize; j++) {
int wordId = corpus.get(docId).get(j);
int subtopic = FuncUtils.nextDiscrete(multiPros);
int topic = subtopic % numTopics;
if (topic == subtopic) { // Generated from the latent feature component
topicWordCountLF[topic][wordId] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {// Generated from the Dirichlet multinomial component
topicWordCountLDA[topic][wordId] += 1;
sumTopicWordCountLDA[topic] += 1;
}
docTopicCount[docId][topic] += 1;
sumDocTopicCount[docId] += 1;
topics.add(subtopic);
}
topicAssignments.add(topics);
}
}
public void initialize(String pathToTopicAssignmentFile)
{
System.out.println("Reading topic-assignment file: " + pathToTopicAssignmentFile);
topicAssignments = new ArrayList<List<Integer>>();
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(pathToTopicAssignmentFile));
int docId = 0;
int numWords = 0;
for (String line; (line = br.readLine()) != null;) {
String[] strTopics = line.trim().split("\\s+");
List<Integer> topics = new ArrayList<Integer>();
for (int j = 0; j < strTopics.length; j++) {
int wordId = corpus.get(docId).get(j);
int subtopic = new Integer(strTopics[j]);
int topic = subtopic % numTopics;
if (topic == subtopic) { // Generated from the latent feature component
topicWordCountLF[topic][wordId] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {// Generated from the Dirichlet multinomial component
topicWordCountLDA[topic][wordId] += 1;
sumTopicWordCountLDA[topic] += 1;
}
docTopicCount[docId][topic] += 1;
sumDocTopicCount[docId] += 1;
topics.add(subtopic);
numWords++;
}
topicAssignments.add(topics);
docId++;
}
if ((docId != numDocuments) || (numWords != numWordsInCorpus)) {
System.out
.println("The topic modeling corpus and topic assignment file are not consistent!!!");
throw new Exception();
}
}
catch (Exception e) {
e.printStackTrace();
}
}
public void inference()
throws IOException
{
System.out.println("Running Gibbs sampling inference: ");
for (int iter = 1; iter <= numInitIterations; iter++) {
System.out.println("\tInitial sampling iteration: " + (iter));
sampleSingleInitialIteration();
}
for (int iter = 1; iter <= numIterations; iter++) {
System.out.println("\tLFLDA sampling iteration: " + (iter));
optimizeTopicVectors();
sampleSingleIteration();
if ((savestep > 0) && (iter % savestep == 0) && (iter < numIterations)) {
System.out.println("\t\tSaving the output from the " + iter + "^{th} sample");
expName = orgExpName + "-" + iter;
write();
}
}
expName = orgExpName;
writeParameters();
System.out.println("Writing output from the last sample ...");
write();
System.out.println("Sampling completed!");
}
public void optimizeTopicVectors()
{
System.out.println("\t\tEstimating topic vectors ...");
sumExpValues = new double[numTopics];
dotProductValues = new double[numTopics][vocabularySize];
expDotProductValues = new double[numTopics][vocabularySize];
Parallel.loop(numTopics, new Parallel.LoopInt()
{
@Override
public void compute(int topic)
{
int rate = 1;
boolean check = true;
while (check) {
double l2Value = l2Regularizer * rate;
try {
TopicVectorOptimizer optimizer = new TopicVectorOptimizer(
topicVectors[topic], topicWordCountLF[topic], wordVectors, l2Value);
Optimizer gd = new LBFGS(optimizer, tolerance);
gd.optimize(600);
optimizer.getParameters(topicVectors[topic]);
sumExpValues[topic] = optimizer.computePartitionFunction(
dotProductValues[topic], expDotProductValues[topic]);
check = false;
if (sumExpValues[topic] == 0 || Double.isInfinite(sumExpValues[topic])) {
double max = -1000000000.0;
for (int index = 0; index < vocabularySize; index++) {
if (dotProductValues[topic][index] > max)
max = dotProductValues[topic][index];
}
for (int index = 0; index < vocabularySize; index++) {
expDotProductValues[topic][index] = Math
.exp(dotProductValues[topic][index] - max);
sumExpValues[topic] += expDotProductValues[topic][index];
}
}
}
catch (InvalidOptimizableException e) {
e.printStackTrace();
check = true;
}
rate = rate * 10;
}
}
});
}
public void sampleSingleIteration()
{
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
int docSize = corpus.get(dIndex).size();
for (int wIndex = 0; wIndex < docSize; wIndex++) {
// Get current word
int word = corpus.get(dIndex).get(wIndex);// wordID
int subtopic = topicAssignments.get(dIndex).get(wIndex);
int topic = subtopic % numTopics;
docTopicCount[dIndex][topic] -= 1;
if (subtopic == topic) {
topicWordCountLF[topic][word] -= 1;
sumTopicWordCountLF[topic] -= 1;
}
else {
topicWordCountLDA[topic][word] -= 1;
sumTopicWordCountLDA[topic] -= 1;
}
// Sample a pair of topic z and binary indicator variable s
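// The joint full conditional over (topic t, indicator s) has 2 * numTopics outcomes:
// P(z = t, s = LF | rest) proportional to (docTopicCount[d][t] + alpha)
// * lambda * exp(tau_t . omega_w) / Z_t, stored in multiPros[t];
// P(z = t, s = LDA | rest) proportional to (docTopicCount[d][t] + alpha)
// * (1 - lambda) * (topicWordCountLDA[t][w] + beta) / (sumTopicWordCountLDA[t] + betaSum),
// stored in multiPros[t + numTopics]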
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
multiPros[tIndex] = (docTopicCount[dIndex][tIndex] + alpha) * lambda
* expDotProductValues[tIndex][word] / sumExpValues[tIndex];
multiPros[tIndex + numTopics] = (docTopicCount[dIndex][tIndex] + alpha)
* (1 - lambda) * (topicWordCountLDA[tIndex][word] + beta)
/ (sumTopicWordCountLDA[tIndex] + betaSum);
}
subtopic = FuncUtils.nextDiscrete(multiPros);
topic = subtopic % numTopics;
docTopicCount[dIndex][topic] += 1;
if (subtopic == topic) {
topicWordCountLF[topic][word] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {
topicWordCountLDA[topic][word] += 1;
sumTopicWordCountLDA[topic] += 1;
}
// Update topic assignments
topicAssignments.get(dIndex).set(wIndex, subtopic);
}
}
}
public void sampleSingleInitialIteration()
{
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
int docSize = corpus.get(dIndex).size();
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = corpus.get(dIndex).get(wIndex);// wordID
int subtopic = topicAssignments.get(dIndex).get(wIndex);
int topic = subtopic % numTopics;
docTopicCount[dIndex][topic] -= 1;
if (subtopic == topic) { // LF(w|t) + LDA(t|d)
topicWordCountLF[topic][word] -= 1;
sumTopicWordCountLF[topic] -= 1;
}
else { // LDA(w|t) + LDA(t|d)
topicWordCountLDA[topic][word] -= 1;
sumTopicWordCountLDA[topic] -= 1;
}
// Sample a pair of topic z and binary indicator variable s
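// Initial iterations score the LF component by its smoothed word counts,
// since the topic vectors have not been estimated yet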
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
multiPros[tIndex] = (docTopicCount[dIndex][tIndex] + alpha) * lambda
* (topicWordCountLF[tIndex][word] + beta)
/ (sumTopicWordCountLF[tIndex] + betaSum);
multiPros[tIndex + numTopics] = (docTopicCount[dIndex][tIndex] + alpha)
* (1 - lambda) * (topicWordCountLDA[tIndex][word] + beta)
/ (sumTopicWordCountLDA[tIndex] + betaSum);
}
subtopic = FuncUtils.nextDiscrete(multiPros);
topic = subtopic % numTopics;
docTopicCount[dIndex][topic] += 1;
if (topic == subtopic) {
topicWordCountLF[topic][word] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {
topicWordCountLDA[topic][word] += 1;
sumTopicWordCountLDA[topic] += 1;
}
// Update topic assignments
topicAssignments.get(dIndex).set(wIndex, subtopic);
}
}
}
public void writeParameters()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".paras"));
writer.write("-model" + "\t" + "LFLDA");
writer.write("\n-corpus" + "\t" + corpusPath);
writer.write("\n-vectors" + "\t" + vectorFilePath);
writer.write("\n-ntopics" + "\t" + numTopics);
writer.write("\n-alpha" + "\t" + alpha);
writer.write("\n-beta" + "\t" + beta);
writer.write("\n-lambda" + "\t" + lambda);
writer.write("\n-initers" + "\t" + numInitIterations);
writer.write("\n-niters" + "\t" + numIterations);
writer.write("\n-twords" + "\t" + topWords);
writer.write("\n-name" + "\t" + expName);
if (tAssignsFilePath.length() > 0)
writer.write("\n-initFile" + "\t" + tAssignsFilePath);
if (savestep > 0)
writer.write("\n-sstep" + "\t" + savestep);
writer.close();
}
public void writeDictionary()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
+ ".vocabulary"));
for (String word : word2IdVocabulary.keySet()) {
writer.write(word + " " + word2IdVocabulary.get(word) + "\n");
}
writer.close();
}
public void writeIDbasedCorpus()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
+ ".IDcorpus"));
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
int docSize = corpus.get(dIndex).size();
for (int wIndex = 0; wIndex < docSize; wIndex++) {
writer.write(corpus.get(dIndex).get(wIndex) + " ");
}
writer.write("\n");
}
writer.close();
}
public void writeTopicAssignments()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
+ ".topicAssignments"));
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
int docSize = corpus.get(dIndex).size();
for (int wIndex = 0; wIndex < docSize; wIndex++) {
writer.write(topicAssignments.get(dIndex).get(wIndex) + " ");
}
writer.write("\n");
}
writer.close();
}
public void writeTopicVectors()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
+ ".topicVectors"));
for (int i = 0; i < numTopics; i++) {
for (int j = 0; j < vectorSize; j++)
writer.write(topicVectors[i][j] + " ");
writer.write("\n");
}
writer.close();
}
public void writeTopTopicalWords()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
+ ".topWords"));
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
writer.write("Topic" + new Integer(tIndex) + ":");
Map<Integer, Double> topicWordProbs = new TreeMap<Integer, Double>();
for (int wIndex = 0; wIndex < vocabularySize; wIndex++) {
double pro = lambda * expDotProductValues[tIndex][wIndex] / sumExpValues[tIndex]
+ (1 - lambda) * (topicWordCountLDA[tIndex][wIndex] + beta)
/ (sumTopicWordCountLDA[tIndex] + betaSum);
topicWordProbs.put(wIndex, pro);
}
topicWordProbs = FuncUtils.sortByValueDescending(topicWordProbs);
Set<Integer> mostLikelyWords = topicWordProbs.keySet();
int count = 0;
for (Integer index : mostLikelyWords) {
if (count < topWords) {
writer.write(" " + id2WordVocabulary.get(index));
count += 1;
}
else {
writer.write("\n\n");
break;
}
}
}
writer.close();
}
public void writeTopicWordPros()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".phi"));
for (int t = 0; t < numTopics; t++) {
for (int w = 0; w < vocabularySize; w++) {
double pro = lambda * expDotProductValues[t][w] / sumExpValues[t] + (1 - lambda)
* (topicWordCountLDA[t][w] + beta) / (sumTopicWordCountLDA[t] + betaSum);
writer.write(pro + " ");
}
writer.write("\n");
}
writer.close();
}
public void writeDocTopicPros()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".theta"));
for (int i = 0; i < numDocuments; i++) {
for (int j = 0; j < numTopics; j++) {
double pro = (docTopicCount[i][j] + alpha) / (sumDocTopicCount[i] + alphaSum);
writer.write(pro + " ");
}
writer.write("\n");
}
writer.close();
}
public void write()
throws IOException
{
writeTopTopicalWords();
writeDocTopicPros();
writeTopicAssignments();
writeTopicWordPros();
}
public static void main(String args[])
throws Exception
{
LFLDA lflda = new LFLDA("test/corpus.txt", "test/wordVectors.txt", 4, 0.1, 0.01, 0.6, 2000,
200, 20, "testLFLDA");
lflda.writeParameters();
lflda.inference();
}
}
================================================
FILE: src/models/LFLDA_Inf.java
================================================
package models;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import utility.FuncUtils;
import utility.LBFGS;
import utility.Parallel;
import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.MatrixOps;
/**
* Implementation of the LF-LDA latent feature topic model, using collapsed
* Gibbs sampling, as described in:
*
* Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. 2015.
* Improving Topic Models with Latent Feature Word Representations. Transactions
* of the Association for Computational Linguistics, vol. 3, pp. 299-313.
*
* Inference of topic distributions on an unseen corpus, using the word-topic
* counts of a pre-trained LF-LDA model
*
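* A minimal usage sketch, mirroring the main method of this class: load the
* ".paras" file written by a trained LFLDA model, then infer topics for the
* unseen documents in the test corpus:
*
* <pre>
* LFLDA_Inf inf = new LFLDA_Inf("test/testLFLDA.paras", "test/corpus_test.txt",
*         2000, 200, 20, "testLFLDAinf", 0);
* inf.inference();
* </pre>
*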
* @author Dat Quoc Nguyen
*/
public class LFLDA_Inf
{
public double alpha; // Hyper-parameter alpha
public double beta; // Hyper-parameter beta
public double alphaSum; // alpha * numTopics
public double betaSum; // beta * vocabularySize
public int numTopics; // Number of topics
public int topWords; // Number of most probable words for each topic
public double lambda; // Mixture weight value
public int numInitIterations; // Number of initial sampling iterations
public int numIterations; // Number of EM-style sampling iterations
public List<List<Integer>> corpus; // Word ID-based corpus
public List<List<Integer>> topicAssignments; // Topic assignments for words
// in the corpus
public int numDocuments; // Number of documents in the corpus
public int numWordsInCorpus; // Number of words in the corpus
public HashMap<String, Integer> word2IdVocabulary; // Vocabulary to get ID
// given a word
public HashMap<Integer, String> id2WordVocabulary; // Vocabulary to get word
// given an ID
public int vocabularySize; // The number of word types in the corpus
// numDocuments * numTopics matrix
// Given a document: number of its words assigned to each topic
public int[][] docTopicCount;
// Number of words in every document
public int[] sumDocTopicCount;
// numTopics * vocabularySize matrix
// Given a topic: number of times a word type generated from the topic by
// the Dirichlet multinomial component
public int[][] topicWordCountLDA;
// Total number of words generated from each topic by the Dirichlet
// multinomial component
public int[] sumTopicWordCountLDA;
// numTopics * vocabularySize matrix
// Given a topic: number of times a word type generated from the topic by
// the latent feature component
public int[][] topicWordCountLF;
// Total number of words generated from each topic by the latent feature
// component
public int[] sumTopicWordCountLF;
// Double array used to sample a topic
public double[] multiPros;
// Path to the directory containing the corpus
public String folderPath;
// Path to the topic modeling corpus
public String corpusPath;
public String vectorFilePath;
public double[][] wordVectors; // Vector representations for words
public double[][] topicVectors;// Vector representations for topics
public int vectorSize; // Number of vector dimensions
public double[][] dotProductValues;
public double[][] expDotProductValues;
public double[] sumExpValues; // Partition function values
public final double l2Regularizer = 0.01; // L2 regularizer value for
// learning topic vectors
public final double tolerance = 0.05; // Tolerance value for LBFGS
// convergence
public String expName = "LFLDAinf";
public String orgExpName = "LFLDAinf";
public int savestep = 0;
public LFLDA_Inf(String pathToTrainingParasFile, String pathToUnseenCorpus,
int inNumInitIterations, int inNumIterations, int inTopWords,
String inExpName, int inSaveStep)
throws Exception
{
HashMap<String, String> paras = parseTrainingParasFile(pathToTrainingParasFile);
if (!paras.get("-model").equals("LFLDA")) {
throw new Exception("Wrong pre-trained model!!!");
}
alpha = new Double(paras.get("-alpha"));
beta = new Double(paras.get("-beta"));
lambda = new Double(paras.get("-lambda"));
numTopics = new Integer(paras.get("-ntopics"));
numIterations = inNumIterations;
numInitIterations = inNumInitIterations;
topWords = inTopWords;
savestep = inSaveStep;
expName = inExpName;
orgExpName = expName;
vectorFilePath = paras.get("-vectors");
String trainingCorpus = paras.get("-corpus");
String trainingCorpusfolder = trainingCorpus.substring(
0,
Math.max(trainingCorpus.lastIndexOf("/"),
trainingCorpus.lastIndexOf("\\")) + 1);
String topicAssignment4TrainFile = trainingCorpusfolder
+ paras.get("-name") + ".topicAssignments";
word2IdVocabulary = new HashMap<String, Integer>();
id2WordVocabulary = new HashMap<Integer, String>();
initializeWordCount(trainingCorpus, topicAssignment4TrainFile);
corpusPath = pathToUnseenCorpus;
folderPath = pathToUnseenCorpus.substring(
0,
Math.max(pathToUnseenCorpus.lastIndexOf("/"),
pathToUnseenCorpus.lastIndexOf("\\")) + 1);
System.out.println("Reading unseen corpus: " + pathToUnseenCorpus);
corpus = new ArrayList<List<Integer>>();
numDocuments = 0;
numWordsInCorpus = 0;
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(pathToUnseenCorpus));
for (String doc; (doc = br.readLine()) != null;) {
if (doc.trim().length() == 0)
continue;
String[] words = doc.trim().split("\\s+");
List<Integer> document = new ArrayList<Integer>();
for (String word : words) {
if (word2IdVocabulary.containsKey(word)) {
document.add(word2IdVocabulary.get(word));
}
else {
// Skip this unknown word
}
}
numDocuments++;
numWordsInCorpus += document.size();
corpus.add(document);
}
}
catch (Exception e) {
e.printStackTrace();
}
docTopicCount = new int[numDocuments][numTopics];
sumDocTopicCount = new int[numDocuments];
multiPros = new double[numTopics * 2];
for (int i = 0; i < numTopics * 2; i++) {
multiPros[i] = 1.0 / numTopics;
}
alphaSum = numTopics * alpha;
betaSum = vocabularySize * beta;
readWordVectorsFile(vectorFilePath);
topicVectors = new double[numTopics][vectorSize];
dotProductValues = new double[numTopics][vocabularySize];
expDotProductValues = new double[numTopics][vocabularySize];
sumExpValues = new double[numTopics];
System.out.println("Corpus size: " + numDocuments + " docs, "
+ numWordsInCorpus + " words");
System.out.println("Vocabuary size: " + vocabularySize);
System.out.println("Number of topics: " + numTopics);
System.out.println("alpha: " + alpha);
System.out.println("beta: " + beta);
System.out.println("lambda: " + lambda);
System.out.println("Number of initial sampling iterations: "
+ numInitIterations);
System.out
.println("Number of EM-style sampling iterations for the LF-LDA model: "
+ numIterations);
System.out.println("Number of top topical words: " + topWords);
initialize();
}
private HashMap<String, String> parseTrainingParasFile(
String pathToTrainingParasFile)
throws Exception
{
HashMap<String, String> paras = new HashMap<String, String>();
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(pathToTrainingParasFile));
for (String line; (line = br.readLine()) != null;) {
if (line.trim().length() == 0)
continue;
String[] paraOptions = line.trim().split("\\s+");
paras.put(paraOptions[0], paraOptions[1]);
}
}
catch (Exception e) {
e.printStackTrace();
}
return paras;
}
private void initializeWordCount(String pathToTrainingCorpus,
String pathToTopicAssignmentFile)
{
System.out.println("Loading pre-trained model...");
List<List<Integer>> trainCorpus = new ArrayList<List<Integer>>();
BufferedReader br = null;
try {
int indexWord = -1;
br = new BufferedReader(new FileReader(pathToTrainingCorpus));
for (String doc; (doc = br.readLine()) != null;) {
if (doc.trim().length() == 0)
continue;
String[] words = doc.trim().split("\\s+");
List<Integer> document = new ArrayList<Integer>();
for (String word : words) {
if (word2IdVocabulary.containsKey(word)) {
document.add(word2IdVocabulary.get(word));
}
else {
indexWord += 1;
word2IdVocabulary.put(word, indexWord);
id2WordVocabulary.put(indexWord, word);
document.add(indexWord);
}
}
trainCorpus.add(document);
}
}
catch (Exception e) {
e.printStackTrace();
}
vocabularySize = word2IdVocabulary.size();
topicWordCountLDA = new int[numTopics][vocabularySize];
sumTopicWordCountLDA = new int[numTopics];
topicWordCountLF = new int[numTopics][vocabularySize];
sumTopicWordCountLF = new int[numTopics];
try {
br = new BufferedReader(new FileReader(pathToTopicAssignmentFile));
int docId = 0;
for (String line; (line = br.readLine()) != null;) {
String[] strTopics = line.trim().split("\\s+");
for (int j = 0; j < strTopics.length; j++) {
int wordId = trainCorpus.get(docId).get(j);
int subtopic = new Integer(strTopics[j]);
int topic = subtopic % numTopics;
if (topic == subtopic) { // Generated from the latent
// feature component
topicWordCountLF[topic][wordId] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {// Generated from the Dirichlet multinomial component
topicWordCountLDA[topic][wordId] += 1;
sumTopicWordCountLDA[topic] += 1;
}
}
docId++;
}
}
catch (Exception e) {
e.printStackTrace();
}
}
public void readWordVectorsFile(String pathToWordVectorsFile)
throws Exception
{
System.out.println("Reading word vectors from word-vectors file "
+ pathToWordVectorsFile + "...");
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(pathToWordVectorsFile));
String[] elements = br.readLine().trim().split("\\s+");
vectorSize = elements.length - 1;
wordVectors = new double[vocabularySize][vectorSize];
String word = elements[0];
if (word2IdVocabulary.containsKey(word)) {
for (int j = 0; j < vectorSize; j++) {
wordVectors[word2IdVocabulary.get(word)][j] = new Double(
elements[j + 1]);
}
}
for (String line; (line = br.readLine()) != null;) {
elements = line.trim().split("\\s+");
word = elements[0];
if (word2IdVocabulary.containsKey(word)) {
for (int j = 0; j < vectorSize; j++) {
wordVectors[word2IdVocabulary.get(word)][j] = new Double(
elements[j + 1]);
}
}
}
}
catch (Exception e) {
e.printStackTrace();
}
for (int i = 0; i < vocabularySize; i++) {
if (MatrixOps.absNorm(wordVectors[i]) == 0.0) {
System.out.println("The word \"" + id2WordVocabulary.get(i)
+ "\" doesn't have a corresponding vector!!!");
throw new Exception();
}
}
}
public void initialize()
throws IOException
{
System.out.println("Randomly initialzing topic assignments ...");
topicAssignments = new ArrayList<List<Integer>>();
for (int docId = 0; docId < numDocuments; docId++) {
List<Integer> topics = new ArrayList<Integer>();
int docSize = corpus.get(docId).size();
for (int j = 0; j < docSize; j++) {
int wordId = corpus.get(docId).get(j);
int subtopic = FuncUtils.nextDiscrete(multiPros);
int topic = subtopic % numTopics;
if (topic == subtopic) { // Generated from the latent feature
// component
topicWordCountLF[topic][wordId] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {// Generated from the Dirichlet multinomial component
topicWordCountLDA[topic][wordId] += 1;
sumTopicWordCountLDA[topic] += 1;
}
docTopicCount[docId][topic] += 1;
sumDocTopicCount[docId] += 1;
topics.add(subtopic);
}
topicAssignments.add(topics);
}
}
public void inference()
throws IOException
{
System.out.println("Running Gibbs sampling inference: ");
for (int iter = 1; iter <= numInitIterations; iter++) {
System.out.println("\tInitial sampling iteration: " + (iter));
sampleSingleInitialIteration();
}
for (int iter = 1; iter <= numIterations; iter++) {
System.out.println("\tLFLDA sampling iteration: " + (iter));
optimizeTopicVectors();
sampleSingleIteration();
if ((savestep > 0) && (iter % savestep == 0)
&& (iter < numIterations)) {
System.out.println("\t\tSaving the output from the " + iter
+ "^{th} sample");
expName = orgExpName + "-" + iter;
write();
}
}
expName = orgExpName;
writeParameters();
System.out.println("Writing output from the last sample ...");
write();
System.out.println("Sampling completed!");
}
public void optimizeTopicVectors()
{
System.out.println("\t\tEstimating topic vectors ...");
sumExpValues = new double[numTopics];
dotProductValues = new double[numTopics][vocabularySize];
expDotProductValues = new double[numTopics][vocabularySize];
Parallel.loop(numTopics, new Parallel.LoopInt()
{
@Override
public void compute(int topic)
{
int rate = 1;
boolean check = true;
while (check) {
double l2Value = l2Regularizer * rate;
try {
TopicVectorOptimizer optimizer = new TopicVectorOptimizer(
topicVectors[topic], topicWordCountLF[topic],
wordVectors, l2Value);
Optimizer gd = new LBFGS(optimizer, tolerance);
gd.optimize(600);
optimizer.getParameters(topicVectors[topic]);
sumExpValues[topic] = optimizer
.computePartitionFunction(dotProductValues[topic],
expDotProductValues[topic]);
check = false;
if (sumExpValues[topic] == 0
|| Double.isInfinite(sumExpValues[topic])) {
double max = -1000000000.0;
for (int index = 0; index < vocabularySize; index++) {
if (dotProductValues[topic][index] > max)
max = dotProductValues[topic][index];
}
for (int index = 0; index < vocabularySize; index++) {
expDotProductValues[topic][index] = Math
.exp(dotProductValues[topic][index] - max);
sumExpValues[topic] += expDotProductValues[topic][index];
}
}
}
catch (InvalidOptimizableException e) {
e.printStackTrace();
check = true;
}
rate = rate * 10;
}
}
});
}
public void sampleSingleIteration()
{
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
int docSize = corpus.get(dIndex).size();
for (int wIndex = 0; wIndex < docSize; wIndex++) {
// Get current word
int word = corpus.get(dIndex).get(wIndex);// wordID
int subtopic = topicAssignments.get(dIndex).get(wIndex);
int topic = subtopic % numTopics;
docTopicCount[dIndex][topic] -= 1;
if (subtopic == topic) {
topicWordCountLF[topic][word] -= 1;
sumTopicWordCountLF[topic] -= 1;
}
else {
topicWordCountLDA[topic][word] -= 1;
sumTopicWordCountLDA[topic] -= 1;
}
// Sample a pair of topic z and binary indicator variable s
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
multiPros[tIndex] = (docTopicCount[dIndex][tIndex] + alpha)
* lambda * expDotProductValues[tIndex][word]
/ sumExpValues[tIndex];
multiPros[tIndex + numTopics] = (docTopicCount[dIndex][tIndex] + alpha)
* (1 - lambda)
* (topicWordCountLDA[tIndex][word] + beta)
/ (sumTopicWordCountLDA[tIndex] + betaSum);
}
subtopic = FuncUtils.nextDiscrete(multiPros);
topic = subtopic % numTopics;
docTopicCount[dIndex][topic] += 1;
if (subtopic == topic) {
topicWordCountLF[topic][word] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {
topicWordCountLDA[topic][word] += 1;
sumTopicWordCountLDA[topic] += 1;
}
// Update topic assignments
topicAssignments.get(dIndex).set(wIndex, subtopic);
}
}
}
public void sampleSingleInitialIteration()
{
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
int docSize = corpus.get(dIndex).size();
for (int wIndex = 0; wIndex < docSize; wIndex++) {
int word = corpus.get(dIndex).get(wIndex);// wordID
int subtopic = topicAssignments.get(dIndex).get(wIndex);
int topic = subtopic % numTopics;
docTopicCount[dIndex][topic] -= 1;
if (subtopic == topic) { // LF(w|t) + LDA(t|d)
topicWordCountLF[topic][word] -= 1;
sumTopicWordCountLF[topic] -= 1;
}
else { // LDA(w|t) + LDA(t|d)
topicWordCountLDA[topic][word] -= 1;
sumTopicWordCountLDA[topic] -= 1;
}
// Sample a pair of topic z and binary indicator variable s
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
multiPros[tIndex] = (docTopicCount[dIndex][tIndex] + alpha)
* lambda * (topicWordCountLF[tIndex][word] + beta)
/ (sumTopicWordCountLF[tIndex] + betaSum);
multiPros[tIndex + numTopics] = (docTopicCount[dIndex][tIndex] + alpha)
* (1 - lambda)
* (topicWordCountLDA[tIndex][word] + beta)
/ (sumTopicWordCountLDA[tIndex] + betaSum);
}
subtopic = FuncUtils.nextDiscrete(multiPros);
topic = subtopic % numTopics;
docTopicCount[dIndex][topic] += 1;
if (topic == subtopic) {
topicWordCountLF[topic][word] += 1;
sumTopicWordCountLF[topic] += 1;
}
else {
topicWordCountLDA[topic][word] += 1;
sumTopicWordCountLDA[topic] += 1;
}
// Update topic assignments
topicAssignments.get(dIndex).set(wIndex, subtopic);
}
}
}
public void writeParameters()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".paras"));
writer.write("-model" + "\t" + "LFLDA");
writer.write("\n-corpus" + "\t" + corpusPath);
writer.write("\n-vectors" + "\t" + vectorFilePath);
writer.write("\n-ntopics" + "\t" + numTopics);
writer.write("\n-alpha" + "\t" + alpha);
writer.write("\n-beta" + "\t" + beta);
writer.write("\n-lambda" + "\t" + lambda);
writer.write("\n-initers" + "\t" + numInitIterations);
writer.write("\n-niters" + "\t" + numIterations);
writer.write("\n-twords" + "\t" + topWords);
writer.write("\n-name" + "\t" + expName);
if (savestep > 0)
writer.write("\n-sstep" + "\t" + savestep);
writer.close();
}
public void writeDictionary()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".vocabulary"));
for (String word : word2IdVocabulary.keySet()) {
writer.write(word + " " + word2IdVocabulary.get(word) + "\n");
}
writer.close();
}
public void writeIDbasedCorpus()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".IDcorpus"));
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
int docSize = corpus.get(dIndex).size();
for (int wIndex = 0; wIndex < docSize; wIndex++) {
writer.write(corpus.get(dIndex).get(wIndex) + " ");
}
writer.write("\n");
}
writer.close();
}
public void writeTopicAssignments()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".topicAssignments"));
for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
int docSize = corpus.get(dIndex).size();
for (int wIndex = 0; wIndex < docSize; wIndex++) {
writer.write(topicAssignments.get(dIndex).get(wIndex) + " ");
}
writer.write("\n");
}
writer.close();
}
public void writeTopicVectors()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".topicVectors"));
for (int i = 0; i < numTopics; i++) {
for (int j = 0; j < vectorSize; j++)
writer.write(topicVectors[i][j] + " ");
writer.write("\n");
}
writer.close();
}
public void writeTopTopicalWords()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".topWords"));
for (int tIndex = 0; tIndex < numTopics; tIndex++) {
writer.write("Topic" + new Integer(tIndex) + ":");
Map<Integer, Double> topicWordProbs = new TreeMap<Integer, Double>();
for (int wIndex = 0; wIndex < vocabularySize; wIndex++) {
double pro = lambda * expDotProductValues[tIndex][wIndex]
/ sumExpValues[tIndex] + (1 - lambda)
* (topicWordCountLDA[tIndex][wIndex] + beta)
/ (sumTopicWordCountLDA[tIndex] + betaSum);
topicWordProbs.put(wIndex, pro);
}
topicWordProbs = FuncUtils.sortByValueDescending(topicWordProbs);
Set<Integer> mostLikelyWords = topicWordProbs.keySet();
int count = 0;
for (Integer index : mostLikelyWords) {
if (count < topWords) {
writer.write(" " + id2WordVocabulary.get(index));
count += 1;
}
else {
writer.write("\n\n");
break;
}
}
}
writer.close();
}
public void writeTopicWordPros()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".phi"));
for (int t = 0; t < numTopics; t++) {
for (int w = 0; w < vocabularySize; w++) {
double pro = lambda * expDotProductValues[t][w]
/ sumExpValues[t] + (1 - lambda)
* (topicWordCountLDA[t][w] + beta)
/ (sumTopicWordCountLDA[t] + betaSum);
writer.write(pro + " ");
}
writer.write("\n");
}
writer.close();
}
public void writeDocTopicPros()
throws IOException
{
BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath
+ expName + ".theta"));
for (int i = 0; i < numDocuments; i++) {
for (int j = 0; j < numTopics; j++) {
double pro = (docTopicCount[i][j] + alpha)
/ (sumDocTopicCount[i] + alphaSum);
writer.write(pro + " ");
}
writer.write("\n");
}
writer.close();
}
public void write()
throws IOException
{
writeTopTopicalWords();
writeDocTopicPros();
writeTopicAssignments();
writeTopicWordPros();
}
public static void main(String args[])
throws Exception
{
LFLDA_Inf lflda = new LFLDA_Inf("test/testLFLDA.paras",
"test/corpus_test.txt", 2000, 200, 20, "testLFLDAinf", 0);
lflda.inference();
}
}
================================================
FILE: src/models/TopicVectorOptimizer.java
================================================
package models;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.MatrixOps;
/**
* Implementation of the MAP estimation for learning topic vectors, as described
* in section 3.5 in:
*
* Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. 2015.
* Improving Topic Models with Latent Feature Word Representations. Transactions
* of the Association for Computational Linguistics, vol. 3, pp. 299-313.
*
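* Each topic vector tau_t is found by maximizing the regularized objective
* computed in getValue():
*
* <pre>
* L(tau_t) = sum_w N_{t,w} * (tau_t . omega_w)
*            - N_t * log( sum_w exp(tau_t . omega_w) )
*            - mu * ||tau_t||^2
* </pre>
*
* where N_{t,w} (wordCount[w]) is the number of times word type w is assigned
* to the topic, N_t (totalCount) is their sum, omega_w (wordVectors[w]) is the
* vector of word w, and mu (l2Constant) is the L2 regularizer constant.
*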
* @author Dat Quoc Nguyen
*/
public class TopicVectorOptimizer
implements Optimizable.ByGradientValue
{
// Number of times a word type assigned to the topic
int[] wordCount;
int totalCount; // Total number of words assigned to the topic
int vocaSize; // Size of the vocabulary
// wordCount.length = wordVectors.length = vocaSize
double[][] wordVectors;// Vector representations for words
double[] topicVector;// Vector representation for a topic
int vectorSize; // vectorSize = topicVector.length
// For each i_{th} element of topic vector, compute:
// sum_w wordCount[w] * wordVectors[w][i]
double[] expectedCountValues;
double l2Constant; // L2 regularizer for learning topic vectors
double[] dotProductValues;
double[] expDotProductValues;
public TopicVectorOptimizer(double[] inTopicVector, int[] inWordCount,
double[][] inWordVectors, double inL2Constant)
{
vocaSize = inWordCount.length;
vectorSize = inWordVectors[0].length;
l2Constant = inL2Constant;
topicVector = new double[vectorSize];
System
.arraycopy(inTopicVector, 0, topicVector, 0, inTopicVector.length);
wordCount = new int[vocaSize];
System.arraycopy(inWordCount, 0, wordCount, 0, vocaSize);
wordVectors = new double[vocaSize][vectorSize];
for (int w = 0; w < vocaSize; w++)
System
.arraycopy(inWordVectors[w], 0, wordVectors[w], 0, vectorSize);
totalCount = 0;
for (int w = 0; w < vocaSize; w++) {
totalCount += wordCount[w];
}
expectedCountValues = new double[vectorSize];
for (int i = 0; i < vectorSize; i++) {
for (int w = 0; w < vocaSize; w++) {
expectedCountValues[i] += wordCount[w] * wordVectors[w][i];
}
}
dotProductValues = new double[vocaSize];
expDotProductValues = new double[vocaSize];
}
@Override
public int getNumParameters()
{
return vectorSize;
}
@Override
public void getParameters(double[] buffer)
{
for (int i = 0; i < vectorSize; i++)
buffer[i] = topicVector[i];
}
@Override
public double getParameter(int index)
{
return topicVector[index];
}
@Override
public void setParameters(double[] params)
{
for (int i = 0; i < params.length; i++)
topicVector[i] = params[i];
}
@Override
public void setParameter(int index, double value)
{
topicVector[index] = value;
}
@Override
public void getValueGradient(double[] buffer)
{
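// Gradient of the objective w.r.t. the i-th topic-vector component:
// dL/dtau_i = sum_w N_{t,w} * omega_{w,i} (precomputed in expectedCountValues)
// - N_t * E[omega_i] (expectation of omega_i under the softmax over word types)
// - 2 * mu * tau_i (derivative of the L2 regularizer)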
double partitionFuncValue = computePartitionFunction(dotProductValues,
expDotProductValues);
for (int i = 0; i < vectorSize; i++) {
buffer[i] = 0.0;
double expectationValue = 0.0;
for (int w = 0; w < vocaSize; w++) {
expectationValue += wordVectors[w][i] * expDotProductValues[w];
}
expectationValue = expectationValue / partitionFuncValue;
buffer[i] = expectedCountValues[i] - totalCount * expectationValue
- 2 * l2Constant * topicVector[i];
}
}
@Override
public double getValue()
{
double logPartitionFuncValue = Math.log(computePartitionFunction(
dotProductValues, expDotProductValues));
double value = 0.0;
for (int w = 0; w < vocaSize; w++) {
if (wordCount[w] == 0)
continue;
value += wordCount[w] * dotProductValues[w];
}
value = value - totalCount * logPartitionFuncValue - l2Constant
* MatrixOps.twoNormSquared(topicVector);
return value;
}
// Compute the partition function Z = sum_w exp(tau . omega_w): elements1
// receives the dot products and elements2 their exponentials
public double computePartitionFunction(double[] elements1,
double[] elements2)
{
double value = 0.0;
for (int w = 0; w < vocaSize; w++) {
elements1[w] = MatrixOps.dotProduct(wordVectors[w], topicVector);
elements2[w] = Math.exp(elements1[w]);
value += elements2[w];
}
return value;
}
}
================================================
FILE: src/utility/CmdArgs.java
================================================
package utility;
import org.kohsuke.args4j.Option;
public class CmdArgs
{
@Option(name = "-model", usage = "Specify model", required = true)
public String model = "";
@Option(name = "-corpus", usage = "Specify path to topic modeling corpus")
public String corpus = "";
@Option(name = "-vectors", usage = "Specify path to the file containing word vectors")
public String vectors = "";
@Option(name = "-ntopics", usage = "Specify number of topics")
public int ntopics = 20;
@Option(name = "-alpha", usage = "Specify alpha")
public double alpha = 0.1;
@Option(name = "-beta", usage = "Specify beta")
public double beta = 0.01;
@Option(name = "-lambda", usage = "Specify mixture weight lambda")
public double lambda = 0.6;
@Option(name = "-initers", usage = "Specify number of initial sampling iterations")
public int initers = 2000;
@Option(name = "-niters", usage = "Specify number of EM-style sampling iterations")
public int niters = 200;
@Option(name = "-twords", usage = "Specify number of top topical words")
public int twords = 20;
@Option(name = "-name", usage = "Specify a name to a topic modeling experiment")
public String expModelName = "model";
@Option(name = "-initFile")
public String initTopicAssgns = "";
@Option(name = "-sstep")
public int savestep = 0;
@Option(name = "-dir")
public String dir = "";
@Option(name = "-label")
public String labelFile = "";
@Option(name = "-prob")
public String prob = "";
@Option(name = "-paras", usage = "Specify path to hyper-parameter file")
public String paras = "";
}
================================================
FILE: src/utility/FuncUtils.java
================================================
package utility;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
public class FuncUtils
{
public static <K, V extends Comparable<? super V>> Map<K, V> sortByValueDescending(Map<K, V> map)
{
List<Map.Entry<K, V>> list = new LinkedList<Map.Entry<K, V>>(map.entrySet());
Collections.sort(list, new Comparator<Map.Entry<K, V>>()
{
@Override
public int compare(Map.Entry<K, V> o1, Map.Entry<K, V> o2)
{
int compare = (o1.getValue()).compareTo(o2.getValue());
return -compare;
}
});
Map<K, V> result = new LinkedHashMap<K, V>();
for (Map.Entry<K, V> entry : list) {
result.put(entry.getKey(), entry.getValue());
}
return result;
}
public static <K, V extends Comparable<? super V>> Map<K, V> sortByValueAscending(Map<K, V> map)
{
List<Map.Entry<K, V>> list = new LinkedList<Map.Entry<K, V>>(map.entrySet());
Collections.sort(list, new Comparator<Map.Entry<K, V>>()
{
@Override
public int compare(Map.Entry<K, V> o1, Map.Entry<K, V> o2)
{
int compare = (o1.getValue()).compareTo(o2.getValue());
return compare;
}
});
Map<K, V> result = new LinkedHashMap<K, V>();
for (Map.Entry<K, V> entry : list) {
result.put(entry.getKey(), entry.getValue());
}
return result;
}
/**
* Sample an index from a double array of (possibly unnormalized) non-negative
* weights, with probability proportional to each weight
*
* @param probs unnormalized probabilities
* @return the sampled index
*/
public static int nextDiscrete(double[] probs)
{
double sum = 0.0;
for (int i = 0; i < probs.length; i++)
sum += probs[i];
double r = MTRandom.nextDouble() * sum;
sum = 0.0;
for (int i = 0; i < probs.length; i++) {
sum += probs[i];
if (sum > r)
return i;
}
return probs.length - 1;
}
public static double mean(double[] m)
{
double sum = 0;
for (int i = 0; i < m.length; i++)
sum += m[i];
return sum / m.length;
}
public static double stddev(double[] m)
{
double mean = mean(m);
double s = 0;
for (int i = 0; i < m.length; i++)
s += (m[i] - mean) * (m[i] - mean);
return Math.sqrt(s / m.length);
}
}
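Note that nextDiscrete treats its argument as unnormalized weights, so the samplers can pass raw conditional probabilities without normalizing them first. A small illustrative sketch (the weights, seed, and demo class are arbitrary and not part of the repository):

import utility.FuncUtils;
import utility.MTRandom;
public class SampleDemo
{
public static void main(String[] args)
{
MTRandom.setSeed(20150513); // any fixed seed makes the draws reproducible
double[] weights = { 3.0, 1.0, 1.0 }; // unnormalized: index 0 should be drawn ~60% of the time
int[] counts = new int[weights.length];
for (int n = 0; n < 10000; n++)
counts[FuncUtils.nextDiscrete(weights)]++;
for (int i = 0; i < counts.length; i++)
System.out.println("index " + i + ": " + counts[i]); // roughly 6000 / 2000 / 2000
}
}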
================================================
FILE: src/utility/LBFGS.java
================================================
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a>
*/
/**
Limited Memory BFGS, as described in Byrd, Nocedal, and Schnabel,
"Representations of Quasi-Newton Matrices and Their Use in Limited
Memory Methods"
*/
package utility;
import java.util.LinkedList;
import java.util.logging.Logger;
import cc.mallet.optimize.BackTrackLineSearch;
import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.LineOptimizer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.optimize.OptimizerEvaluator;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
public class LBFGS
implements Optimizer
{
private static Logger logger = MalletLogger
.getLogger("edu.umass.cs.mallet.base.ml.maximize.LimitedMemoryBFGS");
boolean converged = false;
Optimizable.ByGradientValue optimizable;
final int maxIterations = 5000;
// xxx need a more principled stopping point
// final double tolerance = .0001;
// private double tolerance = 1.0e-10;
// final double gradientTolerance = 1.0e-10;
double tolerance;// = 1.0e-3;
final double gradientTolerance = 1.0e-5;
final double eps = 1.0e-10;
// The number of corrections used in BFGS update
// ideally 3 <= m <= 7. Larger m means more cpu time, memory.
final int m = 10;
// Line search function
private LineOptimizer.ByGradient lineMaximizer;
public LBFGS(Optimizable.ByGradientValue function, double inTolerance)
{
tolerance = inTolerance;
this.optimizable = function;
lineMaximizer = new BackTrackLineSearch(function);
}
@Override
public Optimizable getOptimizable()
{
return this.optimizable;
}
@Override
public boolean isConverged()
{
return converged;
}
/**
* Sets the LineOptimizer.ByGradient to use in L-BFGS optimization.
*
* @param lineOpt
* line optimizer for L-BFGS
*/
public void setLineOptimizer(LineOptimizer.ByGradient lineOpt)
{
lineMaximizer = lineOpt;
}
// State of search
// g = gradient
// s = list of m previous "parameters" values
// y = list of m previous "g" values
// rho = intermediate calculation
double[] g, oldg, direction, parameters, oldParameters;
LinkedList s = new LinkedList();
LinkedList y = new LinkedList();
LinkedList rho = new LinkedList();
double[] alpha;
static double step = 1.0;
int iterations;
private OptimizerEvaluator.ByGradient eval = null;
// CPAL - added this
public void setTolerance(double newtol)
{
this.tolerance = newtol;
}
public void setEvaluator(OptimizerEvaluator.ByGradient eval)
{
this.eval = eval;
}
public int getIteration()
{
return iterations;
}
@Override
public boolean optimize()
{
return optimize(Integer.MAX_VALUE);
}
@Override
public boolean optimize(int numIterations)
{
double initialValue = optimizable.getValue();
logger.fine("Entering L-BFGS.optimize(). Initial Value=" + initialValue);
if (g == null) { // first time through
logger.fine("First time through L-BFGS");
iterations = 0;
s = new LinkedList();
y = new LinkedList();
rho = new LinkedList();
alpha = new double[m];
for (int i = 0; i < m; i++)
alpha[i] = 0.0;
parameters = new double[optimizable.getNumParameters()];
oldParameters = new double[optimizable.getNumParameters()];
g = new double[optimizable.getNumParameters()];
oldg = new double[optimizable.getNumParameters()];
direction = new double[optimizable.getNumParameters()];
optimizable.getParameters(parameters);
System.arraycopy(parameters, 0, oldParameters, 0, parameters.length);
optimizable.getValueGradient(g);
System.arraycopy(g, 0, oldg, 0, g.length);
System.arraycopy(g, 0, direction, 0, g.length);
if (MatrixOps.absNormalize(direction) == 0) {
logger.info("L-BFGS initial gradient is zero; saying converged");
g = null;
converged = true;
return true;
}
logger.fine("direction.2norm: " + MatrixOps.twoNorm(direction));
MatrixOps.timesEquals(direction, 1.0 / MatrixOps.twoNorm(direction));
// make initial jump
logger.fine("before initial jump: \ndirection.2norm: " + MatrixOps.twoNorm(direction)
+ " \ngradient.2norm: " + MatrixOps.twoNorm(g) + "\nparameters.2norm: "
+ MatrixOps.twoNorm(parameters));
// TestMaximizable.testValueAndGradientInDirection (maxable,
// direction);
step = lineMaximizer.optimize(direction, step);
if (step == 0.0) {// could not step in this direction.
// // give up and say converged.
// g = null; // reset search
// step = 1.0;
// throw new OptimizationException(
// "Line search could not step in the current direction. "
// +
// "(This is not necessarily cause for alarm. Sometimes this happens close to the maximum,"
// + " where the function may be very flat.)");
return false;
}
optimizable.getParameters(parameters);
optimizable.getValueGradient(g);
logger.fine("after initial jump: \ndirection.2norm: " + MatrixOps.twoNorm(direction)
+ " \ngradient.2norm: " + MatrixOps.twoNorm(g));
}
double value = optimizable.getValue();
for (int iterationCount = 0; iterationCount < numIterations; iterationCount++) {
logger.fine("L-BFGS iteration=" + iterationCount + ", value=" + value + " g.twoNorm: "
+ MatrixOps.twoNorm(g) + " oldg.twoNorm: " + MatrixOps.twoNorm(oldg));
// if (iterationCount % 10 == 0)
// System.out.println("\t\tL-BFGS iteration=" + iterationCount
// + ", value=" + value + " g.twoNorm: "
// + MatrixOps.twoNorm(g) + " oldg.twoNorm: "
// + MatrixOps.twoNorm(oldg));
// get difference between previous 2 gradients and parameters
double sy = 0.0;
double yy = 0.0;
for (int i = 0; i < oldParameters.length; i++) {
// -inf - (-inf) = 0; inf - inf = 0
if (Double.isInfinite(parameters[i]) && Double.isInfinite(oldParameters[i])
&& (parameters[i] * oldParameters[i] > 0))
oldParameters[i] = 0.0;
else
oldParameters[i] = parameters[i] - oldParameters[i];
if (Double.isInfinite(g[i]) && Double.isInfinite(oldg[i]) && (g[i] * oldg[i] > 0))
oldg[i] = 0.0;
else
oldg[i] = g[i] - oldg[i];
sy += oldParameters[i] * oldg[i]; // si * yi
yy += oldg[i] * oldg[i];
direction[i] = g[i];
}
if (sy > 0) {
throw new InvalidOptimizableException("sy = " + sy + " > 0");
}
double gamma = sy / yy; // scaling factor
if (gamma > 0)
throw new InvalidOptimizableException("gamma = " + gamma + " > 0");
push(rho, 1.0 / sy);
push(s, oldParameters);
push(y, oldg);
// calculate new direction
assert (s.size() == y.size()) : "s.size: " + s.size() + " y.size: " + y.size();
for (int i = s.size() - 1; i >= 0; i--) {
alpha[i] = ((Double) rho.get(i)).doubleValue()
* MatrixOps.dotProduct((double[]) s.get(i), direction);
MatrixOps.plusEquals(direction, (double[]) y.get(i), -1.0 * alpha[i]);
}
MatrixOps.timesEquals(direction, gamma);
for (int i = 0; i < y.size(); i++) {
double beta = (((Double) rho.get(i)).doubleValue())
* MatrixOps.dotProduct((double[]) y.get(i), direction);
MatrixOps.plusEquals(direction, (double[]) s.get(i), alpha[i] - beta);
}
for (int i = 0; i < oldg.length; i++) {
oldParameters[i] = parameters[i];
oldg[i] = g[i];
direction[i] *= -1.0;
}
logger.fine("before linesearch: direction.gradient.dotprod: "
+ MatrixOps.dotProduct(direction, g) + "\ndirection.2norm: "
+ MatrixOps.twoNorm(direction) + "\nparameters.2norm: "
+ MatrixOps.twoNorm(parameters));
// TestMaximizable.testValueAndGradientInDirection (maxable,
// direction);
step = lineMaximizer.optimize(direction, step);
if (step == 0.0) { // could not step in this direction.
g = null; // reset search
step = 1.0;
// xxx Temporary test; passed OK
// TestMaximizable.testValueAndGradientInDirection (maxable,
// direction);
// System.out
// .println("\t\tLine search could not step in the current direction.");
// throw new OptimizationException(
// "Line search could not step in the current direction. "
// +
// "(This is not necessarily cause for alarm. Sometimes this happens close to the maximum,"
// + " where the function may be very flat.)");
return false;
}
optimizable.getParameters(parameters);
optimizable.getValueGradient(g);
logger.fine("after linesearch: direction.2norm: " + MatrixOps.twoNorm(direction));
double newValue = optimizable.getValue();
// Test for terminations
// if(2.0*Math.abs(newValue-value) <= tolerance*
// (Math.abs(newValue)+Math.abs(value) + eps)){
if (Math.abs(newValue - value) <= tolerance) {
// System.out.println("\t\tNumber of iterations: "
// + iterationCount);
// System.out
// .println("\t\tExiting L-BFGS on termination #1:\n\t\tvalue difference below "
// + tolerance
// + " (oldValue: "
// + value
// + " newValue: "
// + newValue
// + " gradient.twoNorm: "
// + MatrixOps.twoNorm(g) + ")");
converged = true;
return true;
}
value = newValue;
double gg = MatrixOps.twoNorm(g);
if (gg < gradientTolerance) {
logger.fine("Exiting L-BFGS on termination #2: \ngradient=" + gg + " < "
+ gradientTolerance);
converged = true;
return true;
}
if (gg == 0.0) {
logger.fine("Exiting L-BFGS on termination #3: \ngradient==0.0");
converged = true;
return true;
}
logger.fine("Gradient = " + gg);
iterations++;
if (iterations > maxIterations) {
System.err
.println("Too many iterations in L-BFGS.java. Continuing with current parameters.");
converged = true;
return true;
// throw new IllegalStateException ("Too many iterations.");
}
// end of iteration. call evaluator
if (eval != null && !eval.evaluate(optimizable, iterationCount)) {
logger.fine("Exiting L-BFGS on termination #4: evaluator returned false.");
converged = true;
return false;
}
}
return false;
}
/**
* Resets the previous gradients and values that are used to approximate the Hessian. NOTE - If
* the {@link Optimizable} object is modified externally, this method should be called to avoid
* IllegalStateExceptions.
*/
public void reset()
{
g = null;
}
/**
* Pushes a new object onto the queue l
*
* @param l
* linked list queue of Matrix obj's
* @param toadd
* matrix to push onto queue
*/
private void push(LinkedList l, double[] toadd)
{
assert (l.size() <= m);
if (l.size() == m) {
// remove oldest matrix and add newest to end of list.
// to make this more efficient, actually overwrite
// memory of oldest matrix
// this overwrites the oldest matrix
double[] last = (double[]) l.get(0);
System.arraycopy(toadd, 0, last, 0, toadd.length);
Object ptr = last;
// this readjusts the pointers in the list
for (int i = 0; i < l.size() - 1; i++)
l.set(i, l.get(i + 1));
l.set(m - 1, ptr);
}
else {
double[] newArray = new double[toadd.length];
System.arraycopy(toadd, 0, newArray, 0, toadd.length);
l.addLast(newArray);
}
}
/**
* Pushes a new object onto the queue l
*
* @param l
* linked list queue of Double obj's
* @param toadd
* double value to push onto queue
*/
private void push(LinkedList l, double toadd)
{
assert (l.size() <= m);
if (l.size() == m) { // pop old double and add new
l.removeFirst();
l.addLast(new Double(toadd));
}
else
l.addLast(new Double(toadd));
}
}
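This optimizer maximizes, following MALLET's convention, so the supplied Optimizable.ByGradientValue should be (locally) concave. A toy sketch under that assumption, maximizing the concave quadratic f(x) = -(x0-1)^2 - (x1+2)^2 with its maximum at (1, -2) (the demo class is illustrative, not part of the repository):

import cc.mallet.optimize.Optimizable;
import utility.LBFGS;
public class LbfgsDemo
{
// Concave quadratic with maximum at (1, -2).
static class Quadratic implements Optimizable.ByGradientValue
{
double[] x = { 0.0, 0.0 };
public int getNumParameters() { return x.length; }
public void getParameters(double[] buffer) { System.arraycopy(x, 0, buffer, 0, x.length); }
public void setParameters(double[] params) { System.arraycopy(params, 0, x, 0, x.length); }
public double getParameter(int i) { return x[i]; }
public void setParameter(int i, double v) { x[i] = v; }
public double getValue()
{
return -(x[0] - 1) * (x[0] - 1) - (x[1] + 2) * (x[1] + 2);
}
public void getValueGradient(double[] buffer)
{
buffer[0] = -2 * (x[0] - 1); // partial derivatives of f
buffer[1] = -2 * (x[1] + 2);
}
}
public static void main(String[] args)
{
Quadratic q = new Quadratic();
LBFGS optimizer = new LBFGS(q, 1e-6); // tolerance on the change in value between iterations
optimizer.optimize(100);
System.out.println("x = (" + q.x[0] + ", " + q.x[1] + ")"); // expect approximately (1, -2)
}
}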
================================================
FILE: src/utility/MTRandom.java
================================================
package utility;
public class MTRandom
{
private static MersenneTwister rand = new MersenneTwister();
public static void setSeed(long seed)
{
rand.setSeed(seed);
}
public static double nextDouble()
{
return rand.nextDouble();
}
public static int nextInt(int n)
{
return rand.nextInt(n);
}
public static boolean nextBoolean()
{
return rand.nextBoolean();
}
}
================================================
FILE: src/utility/MersenneTwister.java
================================================
package utility;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
/**
* <h3>MersenneTwister and MersenneTwisterFast</h3>
* <p>
* <b>Version 20</b>, based on version MT19937 (99/10/29) of the Mersenne Twister algorithm found at
* <a href="http://www.math.keio.ac.jp/matumoto/emt.html"> The Mersenne Twister Home Page</a>, with
* the initialization improved using the new 2002/1/26 initialization algorithm By Sean Luke,
* October 2004.
*
* <p>
* <b>MersenneTwister</b> is a drop-in subclass replacement for java.util.Random. It is properly
* synchronized and can be used in a multithreaded environment. On modern VMs such as HotSpot, it is
* approximately 1/3 slower than java.util.Random.
*
* <p>
* <b>MersenneTwisterFast</b> is not a subclass of java.util.Random. It has the same public methods
* as Random does, however, and it is algorithmically identical to MersenneTwister.
* MersenneTwisterFast has hard-code inlined all of its methods directly, and made all of them final
* (well, the ones of consequence anyway). Further, these methods are <i>not</i> synchronized, so
* the same MersenneTwisterFast instance cannot be shared by multiple threads. But all this helps
* MersenneTwisterFast achieve well over twice the speed of MersenneTwister. java.util.Random is
* about 1/3 slower than MersenneTwisterFast.
*
* <h3>About the Mersenne Twister</h3>
* <p>
* This is a Java version of the C-program for MT19937: Integer version. The MT19937 algorithm was
* created by Makoto Matsumoto and Takuji Nishimura, who ask: "When you use this, send an email to:
* matumoto@math.keio.ac.jp with an appropriate reference to your work". Indicate that this is a
* translation of their algorithm into Java.
*
* <p>
* <b>Reference. </b> Makoto Matsumoto and Takuji Nishimura, "Mersenne Twister: A 623-Dimensionally
* Equidistributed Uniform Pseudo-Random Number Generator", <i>ACM Transactions on Modeling and
* Computer Simulation,</i> Vol. 8, No. 1, January 1998, pp 3--30.
*
* <h3>About this Version</h3>
*
* <p>
* <b>Changes since V19:</b> nextFloat(boolean, boolean) now returns float, not double.
*
* <p>
* <b>Changes since V18:</b> Removed old final declarations, which used to potentially speed up the
* code, but no longer.
*
* <p>
* <b>Changes since V17:</b> Removed vestigial references to &= 0xffffffff which stemmed from the
* original C code. The C code could not guarantee that ints were 32 bit, hence the masks. The
* vestigial references in the Java code were likely optimized out anyway.
*
* <p>
* <b>Changes since V16:</b> Added nextDouble(includeZero, includeOne) and nextFloat(includeZero,
* includeOne) to allow for half-open, fully-closed, and fully-open intervals.
*
* <p>
* <b>Changes Since V15:</b> Added serialVersionUID to quiet compiler warnings from Sun's overly
* verbose compilers as of JDK 1.5.
*
* <p>
* <b>Changes Since V14:</b> made strictfp, with StrictMath.log and StrictMath.sqrt in nextGaussian
* instead of Math.log and Math.sqrt. This is largely just to be safe, as it presently makes no
* difference in the speed, correctness, or results of the algorithm.
*
* <p>
* <b>Changes Since V13:</b> clone() method CloneNotSupportedException removed.
*
* <p>
* <b>Changes Since V12:</b> clone() method added.
*
* <p>
* <b>Changes Since V11:</b> stateEquals(...) method added. MersenneTwisterFast is equal to other
* MersenneTwisterFasts with identical state; likewise MersenneTwister is equal to other
* MersenneTwister with identical state. This isn't equals(...) because that requires a contract of
* immutability to compare by value.
*
* <p>
* <b>Changes Since V10:</b> A documentation error suggested that setSeed(int[]) required an int[]
* array 624 long. In fact, the array can be any non-zero length. The new version also checks for
* this fact.
*
* <p>
* <b>Changes Since V9:</b> readState(stream) and writeState(stream) provided.
*
* <p>
* <b>Changes Since V8:</b> setSeed(int) was only using the first 28 bits of the seed; it should
* have been 32 bits. For small-number seeds the behavior is identical.
*
* <p>
* <b>Changes Since V7:</b> A documentation error in MersenneTwisterFast (but not MersenneTwister)
* stated that nextDouble selects uniformly from the full-open interval [0,1]. It does not.
* nextDouble's contract is identical across MersenneTwisterFast, MersenneTwister, and
* java.util.Random, namely, selection in the half-open interval [0,1). That is, 1.0 should not be
* returned. A similar contract exists in nextFloat.
*
* <p>
* <b>Changes Since V6:</b> License has changed from LGPL to BSD. New timing information to compare
* against java.util.Random. Recent versions of HotSpot have helped Random increase in speed to the
* point where it is faster than MersenneTwister but slower than MersenneTwisterFast (which should
* be the case, as it's a less complex algorithm but is synchronized).
*
* <p>
* <b>Changes Since V5:</b> New empty constructor made to work the same as java.util.Random --
* namely, it seeds based on the current time in milliseconds.
*
* <p>
* <b>Changes Since V4:</b> New initialization algorithms. See <a
* href="http://www.math.keio.ac.jp/matumoto/MT2002/emt19937ar.html">
* http://www.math.keio.ac.jp/matumoto/MT2002/emt19937ar.html</a>
*
* <p>
* The MersenneTwister code is based on standard MT19937 C/C++ code by Takuji Nishimura, with
* suggestions from Topher Cooper and Marc Rieffel, July 1997. The code was originally translated
* into Java by Michael Lecuyer, January 1999, and the original code is Copyright (c) 1999 by
* Michael Lecuyer.
*
* <h3>Java notes</h3>
*
* <p>
* This implementation implements the bug fixes made in Java 1.2's version of Random, which means it
* can be used with earlier versions of Java. See <a
* href="http://www.javasoft.com/products/jdk/1.2/docs/api/java/util/Random.html"> the JDK 1.2
* java.util.Random documentation</a> for further documentation on the random-number generation
* contracts made. Additionally, there's an undocumented bug in the JDK java.util.Random.nextBytes()
* method, which this code fixes.
*
* <p>
* Just like java.util.Random, this generator accepts a long seed but doesn't use all of it.
* java.util.Random uses 48 bits. The Mersenne Twister instead uses 32 bits (int size). So it's best
* if your seed does not exceed the int range.
*
* <p>
* MersenneTwister can be used reliably on JDK version 1.1.5 or above. Earlier Java versions have
* serious bugs in java.util.Random; only MersenneTwisterFast (and not MersenneTwister nor
* java.util.Random) should be used with them.
*
* <h3>License</h3>
*
* Copyright (c) 2003 by Sean Luke. <br>
* Portions copyright (c) 1993 by Michael Lecuyer. <br>
* All rights reserved. <br>
*
* <p>
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* <ul>
* <li>Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* <li>Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials provided with
* the distribution.
* <li>Neither the name of the copyright owners, their employers, nor the names of its contributors
* may be used to endorse or promote products derived from this software without specific prior
* written permission.
* </ul>
* <p>
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNERS OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
* WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* @version 20
*/
public strictfp class MersenneTwister
extends java.util.Random
implements Serializable, Cloneable
{
// Serialization
private static final long serialVersionUID = -4035832775130174188L; // locked as of Version 15
// Period parameters
private static final int N = 624;
private static final int M = 397;
private static final int MATRIX_A = 0x9908b0df; // private static final * constant vector a
private static final int UPPER_MASK = 0x80000000; // most significant w-r bits
private static final int LOWER_MASK = 0x7fffffff; // least significant r bits
// Tempering parameters
private static final int TEMPERING_MASK_B = 0x9d2c5680;
private static final int TEMPERING_MASK_C = 0xefc60000;
private int mt[]; // the array for the state vector
private int mti; // mti==N+1 means mt[N] is not initialized
private int mag01[];
// a good initial seed (of int size, though stored in a long)
// private static final long GOOD_SEED = 4357;
/*
* implemented here because there's a bug in Random's implementation of the Gaussian code
* (divide by zero, and log(0), ugh!), yet its gaussian variables are private so we can't access
* them here. :-(
*/
private double __nextNextGaussian;
private boolean __haveNextNextGaussian;
/* We're overriding all internal data, to my knowledge, so this should be okay */
public Object clone()
{
try {
MersenneTwister f = (MersenneTwister) (super.clone());
f.mt = (int[]) (mt.clone());
f.mag01 = (int[]) (mag01.clone());
return f;
}
catch (CloneNotSupportedException e) {
throw new InternalError();
} // should never happen
}
public boolean stateEquals(Object o)
{
if (o == this)
return true;
if (o == null || !(o instanceof MersenneTwister))
return false;
MersenneTwister other = (MersenneTwister) o;
if (mti != other.mti)
return false;
for (int x = 0; x < mag01.length; x++)
if (mag01[x] != other.mag01[x])
return false;
for (int x = 0; x < mt.length; x++)
if (mt[x] != other.mt[x])
return false;
return true;
}
/** Reads the entire state of the MersenneTwister RNG from the stream */
public void readState(DataInputStream stream)
throws IOException
{
int len = mt.length;
for (int x = 0; x < len; x++)
mt[x] = stream.readInt();
len = mag01.length;
for (int x = 0; x < len; x++)
mag01[x] = stream.readInt();
mti = stream.readInt();
__nextNextGaussian = stream.readDouble();
__haveNextNextGaussian = stream.readBoolean();
}
/** Writes the entire state of the MersenneTwister RNG to the stream */
public void writeState(DataOutputStream stream)
throws IOException
{
int len = mt.length;
for (int x = 0; x < len; x++)
stream.writeInt(mt[x]);
len = mag01.length;
for (int x = 0; x < len; x++)
stream.writeInt(mag01[x]);
stream.writeInt(mti);
stream.writeDouble(__nextNextGaussian);
stream.writeBoolean(__haveNextNextGaussian);
}
/**
* Constructor using the default seed.
*/
public MersenneTwister()
{
this(System.currentTimeMillis());
}
/**
* Constructor using a given seed. Though you pass this seed in as a long, it's best to make
* sure it's actually an integer.
*/
public MersenneTwister(long seed)
{
super(seed); /* just in case */
setSeed(seed);
}
/**
* Constructor using an array of integers as seed. Your array must have a non-zero length. Only
* the first 624 integers in the array are used; if the array is shorter than this then integers
* are repeatedly used in a wrap-around fashion.
*/
public MersenneTwister(int[] array)
{
super(System.currentTimeMillis()); /* pick something at random just in case */
setSeed(array);
}
/**
* Initialize the pseudo-random number generator. Don't pass in a long that's bigger than an int
* (Mersenne Twister only uses the first 32 bits for its seed).
*/
synchronized public void setSeed(long seed)
{
// it's always good style to call super
super.setSeed(seed);
// Due to a bug in java.util.Random clear up to 1.2, we're
// doing our own Gaussian variable.
__haveNextNextGaussian = false;
mt = new int[N];
mag01 = new int[2];
mag01[0] = 0x0;
mag01[1] = MATRIX_A;
mt[0] = (int) seed; // only the lower 32 bits of the seed are used
for (mti = 1; mti < N; mti++) {
mt[mti] = (1812433253 * (mt[mti - 1] ^ (mt[mti - 1] >>> 30)) + mti);
/* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */
/* In the previous versions, MSBs of the seed affect */
/* only MSBs of the array mt[]. */
/* 2002/01/09 modified by Makoto Matsumoto */
// mt[mti] &= 0xffffffff;
/* for >32 bit machines */
}
}
/**
* Sets the seed of the MersenneTwister using an array of integers. Your array must have a
* non-zero length. Only the first 624 integers in the array are used; if the array is shorter
* than this then integers are repeatedly used in a wrap-around fashion.
*/
synchronized public void setSeed(int[] array)
{
if (array.length == 0)
throw new IllegalArgumentException("Array length must be greater than zero");
int i, j, k;
setSeed(19650218);
i = 1;
j = 0;
k = (N > array.length ? N : array.length);
for (; k != 0; k--) {
mt[i] = (mt[i] ^ ((mt[i - 1] ^ (mt[i - 1] >>> 30)) * 1664525)) + array[j] + j; /*
* non
* linear
*/
// mt[i] &= 0xffffffff; /* for WORDSIZE > 32 machines */
i++;
j++;
if (i >= N) {
mt[0] = mt[N - 1];
i = 1;
}
if (j >= array.length)
j = 0;
}
for (k = N - 1; k != 0; k--) {
mt[i] = (mt[i] ^ ((mt[i - 1] ^ (mt[i - 1] >>> 30)) * 1566083941)) - i; /* non linear */
// mt[i] &= 0xffffffff; /* for WORDSIZE > 32 machines */
i++;
if (i >= N) {
mt[0] = mt[N - 1];
i = 1;
}
}
mt[0] = 0x80000000; /* MSB is 1; assuring non-zero initial array */
}
/**
* Returns an integer with <i>bits</i> bits filled with a random number.
*/
synchronized protected int next(int bits)
{
int y;
if (mti >= N) // generate N words at one time
{
int kk;
final int[] mt = this.mt; // locals are slightly faster
final int[] mag01 = this.mag01; // locals are slightly faster
for (kk = 0; kk < N - M; kk++) {
y = (mt[kk] & UPPER_MASK) | (mt[kk + 1] & LOWER_MASK);
mt[kk] = mt[kk + M] ^ (y >>> 1) ^ mag01[y & 0x1];
}
for (; kk < N - 1; kk++) {
y = (mt[kk] & UPPER_MASK) | (mt[kk + 1] & LOWER_MASK);
mt[kk] = mt[kk + (M - N)] ^ (y >>> 1) ^ mag01[y & 0x1];
}
y = (mt[N - 1] & UPPER_MASK) | (mt[0] & LOWER_MASK);
mt[N - 1] = mt[M - 1] ^ (y >>> 1) ^ mag01[y & 0x1];
mti = 0;
}
y = mt[mti++];
y ^= y >>> 11; // TEMPERING_SHIFT_U(y)
y ^= (y << 7) & TEMPERING_MASK_B; // TEMPERING_SHIFT_S(y)
y ^= (y << 15) & TEMPERING_MASK_C; // TEMPERING_SHIFT_T(y)
y ^= (y >>> 18); // TEMPERING_SHIFT_L(y)
return y >>> (32 - bits); // hope that's right!
}
/*
* If you've got a truly old version of Java, you can omit these two next methods.
*/
private synchronized void writeObject(ObjectOutputStream out)
throws IOException
{
// just so we're synchronized.
out.defaultWriteObject();
}
private synchronized void readObject(ObjectInputStream in)
throws IOException, ClassNotFoundException
{
// just so we're synchronized.
in.defaultReadObject();
}
/**
* This method is missing from jdk 1.0.x and below. JDK 1.1 includes this for us, but what the
* heck.
*/
public boolean nextBoolean()
{
return next(1) != 0;
}
/**
* This generates a coin flip with a probability <tt>probability</tt> of returning true, else
* returning false. <tt>probability</tt> must be between 0.0 and 1.0, inclusive. Not as precise
* a random real event as nextBoolean(double), but twice as fast. To explicitly use this,
* remember you may need to cast to float first.
*/
public boolean nextBoolean(float probability)
{
if (probability < 0.0f || probability > 1.0f)
throw new IllegalArgumentException("probability must be between 0.0 and 1.0 inclusive.");
if (probability == 0.0f)
return false; // fix half-open issues
else if (probability == 1.0f)
return true; // fix half-open issues
return nextFloat() < probability;
}
/**
* This generates a coin flip with a probability <tt>probability</tt> of returning true, else
* returning false. <tt>probability</tt> must be between 0.0 and 1.0, inclusive.
*/
public boolean nextBoolean(double probability)
{
if (probability < 0.0 || probability > 1.0)
throw new IllegalArgumentException("probability must be between 0.0 and 1.0 inclusive.");
if (probability == 0.0)
return false; // fix half-open issues
else if (probability == 1.0)
return true; // fix half-open issues
return nextDouble() < probability;
}
/**
* This method is missing from JDK 1.1 and below. JDK 1.2 includes this for us, but what the
* heck.
*/
public int nextInt(int n)
{
if (n <= 0)
throw new IllegalArgumentException("n must be positive, got: " + n);
if ((n & -n) == n)
return (int) ((n * (long) next(31)) >> 31);
int bits, val;
do {
bits = next(31);
val = bits % n;
}
while (bits - val + (n - 1) < 0);
return val;
}
/**
* This method is for completeness' sake. Returns a long drawn uniformly from 0 to n-1. Suffice
* it to say, n must be > 0, or an IllegalArgumentException is raised.
*/
public long nextLong(long n)
{
if (n <= 0)
throw new IllegalArgumentException("n must be positive, got: " + n);
long bits, val;
do {
bits = (nextLong() >>> 1);
val = bits % n;
}
while (bits - val + (n - 1) < 0);
return val;
}
/**
* A bug fix for versions of JDK 1.1 and below. JDK 1.2 fixes this for us, but what the heck.
*/
public double nextDouble()
{
return (((long) next(26) << 27) + next(27)) / (double) (1L << 53);
}
/**
* Returns a double in the range from 0.0 to 1.0, possibly inclusive of 0.0 and 1.0 themselves.
* Thus:
*
* <p>
* <table border=0>
* <th>
* <td>Expression
* <td>Interval
* <tr>
* <td>nextDouble(false, false)
* <td>(0.0, 1.0)
* <tr>
* <td>nextDouble(true, false)
* <td>[0.0, 1.0)
* <tr>
* <td>nextDouble(false, true)
* <td>(0.0, 1.0]
* <tr>
* <td>nextDouble(true, true)
* <td>[0.0, 1.0]
* </table>
*
* <p>
* This version preserves all possible random values in the double range.
*/
public double nextDouble(boolean includeZero, boolean includeOne)
{
double d = 0.0;
do {
d = nextDouble(); // grab a value, initially from half-open [0.0, 1.0)
if (includeOne && nextBoolean())
d += 1.0; // if includeOne, with 1/2 probability, push to [1.0, 2.0)
}
while ((d > 1.0) || // everything above 1.0 is always invalid
(!includeZero && d == 0.0)); // if we're not including zero, 0.0 is invalid
return d;
}
/**
* A bug fix for versions of JDK 1.1 and below. JDK 1.2 fixes this for us, but what the heck.
*/
public float nextFloat()
{
return next(24) / ((float) (1 << 24));
}
/**
* Returns a float in the range from 0.0f to 1.0f, possibly inclusive of 0.0f and 1.0f
* themselves. Thus:
*
* <p>
* <table border=0>
* <th>
* <td>Expression
* <td>Interval
* <tr>
* <td>nextFloat(false, false)
* <td>(0.0f, 1.0f)
* <tr>
* <td>nextFloat(true, false)
* <td>[0.0f, 1.0f)
* <tr>
* <td>nextFloat(false, true)
* <td>(0.0f, 1.0f]
* <tr>
* <td>nextFloat(true, true)
* <td>[0.0f, 1.0f]
* </table>
*
* <p>
* This version preserves all possible random values in the float range.
*/
public float nextFloat(boolean includeZero, boolean includeOne)
{
float d = 0.0f;
do {
d = nextFloat(); // grab a value, initially from half-open [0.0f, 1.0f)
if (includeOne && nextBoolean())
d += 1.0f; // if includeOne, with 1/2 probability, push to [1.0f, 2.0f)
}
while ((d > 1.0f) || // everything above 1.0f is always invalid
(!includeZero && d == 0.0f)); // if we're not including zero, 0.0f is invalid
return d;
}
/**
* A bug fix for all versions of the JDK. The JDK appears to use all four bytes in an integer as
* independent byte values! Totally wrong. I've submitted a bug report.
*/
public void nextBytes(byte[] bytes)
{
for (int x = 0; x < bytes.length; x++)
bytes[x] = (byte) next(8);
}
/** For completeness' sake, though it's not in java.util.Random. */
public char nextChar()
{
// chars are 16-bit Unicode values
return (char) (next(16));
}
/** For completeness' sake, though it's not in java.util.Random. */
public short nextShort()
{
return (short) (next(16));
}
/** For completeness' sake, though it's not in java.util.Random. */
public byte nextByte()
{
return (byte) (next(8));
}
/**
* A bug fix for all JDK code including 1.2. nextGaussian can theoretically ask for the log of 0
* and divide it by 0! See Java bug <a
* href="http://developer.java.sun.com/developer/bugParade/bugs/4254501.html">
* http://developer.java.sun.com/developer/bugParade/bugs/4254501.html</a>
*/
synchronized public double nextGaussian()
{
if (__haveNextNextGaussian) {
__haveNextNextGaussian = false;
return __nextNextGaussian;
}
else {
double v1, v2, s;
do {
v1 = 2 * nextDouble() - 1; // between -1.0 and 1.0
v2 = 2 * nextDouble() - 1; // between -1.0 and 1.0
s = v1 * v1 + v2 * v2;
}
while (s >= 1 || s == 0);
double multiplier = StrictMath.sqrt(-2 * StrictMath.log(s) / s);
__nextNextGaussian = v2 * multiplier;
__haveNextNextGaussian = true;
return v1 * multiplier;
}
}
/**
* Tests the code.
*/
public static void main(String args[])
{
int j;
MersenneTwister r;
// CORRECTNESS TEST
// COMPARE WITH http://www.math.keio.ac.jp/matumoto/CODES/MT2002/mt19937ar.out
r = new MersenneTwister(new int[] { 0x123, 0x234, 0x345, 0x456 });
System.out.println("Output of MersenneTwister with new (2002/1/26) seeding mechanism");
for (j = 0; j < 1000; j++) {
// first, convert the int from signed to "unsigned"
long l = (long) r.nextInt();
if (l < 0)
l += 4294967296L; // add 2^32 to map a negative int to its unsigned value
String s = String.valueOf(l);
while (s.length() < 10)
s = " " + s; // buffer
System.out.print(s + " ");
if (j % 5 == 4)
System.out.println();
}
// SPEED TEST
final long SEED = 4357;
int xx;
long ms;
System.out.println("\nTime to test grabbing 100000000 ints");
r = new MersenneTwister(SEED);
ms = System.currentTimeMillis();
xx = 0;
for (j = 0; j < 100000000; j++)
xx += r.nextInt();
System.out.println("Mersenne Twister: " + (System.currentTimeMillis() - ms)
+ " Ignore this: " + xx);
System.out
.println("To compare this with java.util.Random, run this same test on MersenneTwisterFast.");
System.out
.println("The comparison with Random is removed from MersenneTwister because it is a proper");
System.out
.println("subclass of Random and this unfairly makes some of Random's methods un-inlinable,");
System.out.println("so it would make Random look worse than it is.");
// TEST TO COMPARE TYPE CONVERSION BETWEEN
// MersenneTwisterFast.java AND MersenneTwister.java
System.out.println("\nGrab the first 1000 booleans");
r = new MersenneTwister(SEED);
for (j = 0; j < 1000; j++) {
System.out.print(r.nextBoolean() + " ");
if (j % 8 == 7)
System.out.println();
}
if (!(j % 8 == 7))
System.out.println();
System.out
.println("\nGrab 1000 booleans of increasing probability using nextBoolean(double)");
r = new MersenneTwister(SEED);
for (j = 0; j < 1000; j++) {
System.out.print(r.nextBoolean((double) (j / 999.0)) + " ");
if (j % 8 == 7)
System.out.println();
}
if (!(j % 8 == 7))
System.out.println();
System.out
.println("\nGrab 1000 booleans of increasing probability using nextBoolean(float)");
r = new MersenneTwister(SEED);
for (j = 0; j < 1000; j++) {
System.out.print(r.nextBoolean((float) (j / 999.0f)) + " ");
if (j % 8 == 7)
System.out.println();
}
if (!(j % 8 == 7))
System.out.println();
byte[] bytes = new byte[1000];
System.out.println("\nGrab the first 1000 bytes using nextBytes");
r = new MersenneTwister(SEED);
r.nextBytes(bytes);
for (j = 0; j < 1000; j++) {
System.out.print(bytes[j] + " ");
if (j % 16 == 15)
System.out.println();
}
if (!(j % 16 == 15))
System.out.println();
byte b;
System.out.println("\nGrab the first 1000 bytes -- must be same as nextBytes");
r = new MersenneTwister(SEED);
for (j = 0; j < 1000; j++) {
System.out.print((b = r.nextByte()) + " ");
if (b != bytes[j])
System.out.print("BAD ");
if (j % 16 == 15)
System.out.println();
}
if (!(j % 16 == 15))
System.out.println();
System.out.println("\nGrab the first 1000 shorts");
r = new MersenneTwister(SEED);
for (j = 0; j < 1000; j++) {
System.out.print(r.nextShort() + " ");
if (j % 8 == 7)
System.out.println();
}
if (!(j % 8 == 7))
System.out.println();
System.out.println("\nGrab the first 1000 ints");
r = new MersenneTwister(SEED);
for (j = 0; j < 1000; j++) {
System.out.print(r.nextInt() + " ");
if (j % 4 == 3)
System.out.println();
}
if (!(j % 4 == 3))
System.out.println();
System.out.println("\nGrab the first 1000 ints of different sizes");
r = new MersenneTwister(SEED);
int max = 1;
for (j = 0; j < 1000; j++) {
System.out.print(r.nextInt(max) + " ");
max *= 2;
if (max <= 0)
max = 1;
if (j % 4 == 3)
System.out.println();
}
if (!(j % 4 == 3))
System.out.println();
System.out.println("\nGrab the first 1000 longs");
r = new MersenneTwister(SEED);
for (j = 0; j < 1000; j++) {
System.out.print(r.nextLong() + " ");
if (j % 3 == 2)
System.out.println();
}
if (!(j % 3 == 2))
System.out.println();
System.out.println("\nGrab the first 1000 longs of different sizes");
r = new MersenneTwister(SEED);
long max2 = 1;
for (j = 0; j < 1000; j++) {
System.out.print(r.nextLong(max2) + " ");
max2 *= 2;
if (max2 <= 0)
max2 = 1;
if (j % 4 == 3)
System.out.println();
}
if (!(j % 4 == 3))
System.out.println();
System.out.println("\nGrab the first 1000 floats");
r = new MersenneTwister(SEED);
for (j = 0; j < 1000; j++) {
System.out.print(r.nextFloat() + " ");
if (j % 4 == 3)
System.out.println();
}
if (!(j % 4 == 3))
System.out.println();
System.out.println("\nGrab the first 1000 doubles");
r = new MersenneTwister(SEED);
for (j = 0; j < 1000; j++) {
System.out.print(r.nextDouble() + " ");
if (j % 3 == 2)
System.out.println();
}
if (!(j % 3 == 2))
System.out.println();
System.out.println("\nGrab the first 1000 gaussian doubles");
r = new MersenneTwister(SEED);
for (j = 0; j < 1000; j++) {
System.out.print(r.nextGaussian() + " ");
if (j % 3 == 2)
System.out.println();
}
if (!(j % 3 == 2))
System.out.println();
}
}
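One practical consequence of setSeed(long) keeping only the lower 32 bits: two long seeds that differ only above bit 31 produce identical streams, as the class javadoc warns. A quick illustrative sketch (the demo class is not part of the repository):

import utility.MersenneTwister;
public class SeedDemo
{
public static void main(String[] args)
{
MersenneTwister a = new MersenneTwister(42L);
MersenneTwister b = new MersenneTwister(42L + (1L << 32)); // differs only above bit 31
for (int i = 0; i < 5; i++)
System.out.println(a.nextInt() + " == " + b.nextInt()); // the pairs are identical
}
}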
================================================
FILE: src/utility/Parallel.java
================================================
package utility;
import java.util.Collection;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;
/**
* Utilities for parallel computing in loops over independent tasks. This class
* provides convenient methods for parallel processing of tasks that involve
* loops over indices, in which computations for different indices are
* independent.
* <p>
* As a simple example, consider the following function that squares floats in
* one array and stores the results in a second array.
*
* <pre>
* <code>
* static void sqr(float[] a, float[] b) {
* int n = a.length;
* for (int i=0; i<n; ++i)
* b[i] = a[i]*a[i];
* }
* </code>
* </pre>
*
* A serial version of a similar function for 2D arrays is:
*
* <pre>
* <code>
* static void sqrSerial(float[][] a, float[][] b)
* {
* int n = a.length;
* for (int i=0; i<n; ++i) {
* sqr(a[i],b[i]);
* }
* }
* </code>
* </pre>
*
* Using this class, the parallel version for 2D arrays is:
*
* <pre>
* <code>
* static void sqrParallel(final float[][] a, final float[][] b) {
* int n = a.length;
* Parallel.loop(n,new Parallel.LoopInt() {
* public void compute(int i) {
* sqr(a[i],b[i]);
* }
* });
* }
* </code>
* </pre>
*
* In the parallel version, the method {@code compute} defined by the interface
* {@code LoopInt} will be called n times for different indices i in the range
* [0,n-1]. The order of indices is both indeterminate and irrelevant because
* the computation for each index i is independent. The arrays a and b are
* declared final as required for use in the implementation of {@code LoopInt}.
* <p>
* Note: because the method {@code loop} and interface {@code LoopInt} are
* static members of this class, we can omit the class name prefix
* {@code Parallel} if we first import these names with
*
* <pre>
* <code>
* import static edu.mines.jtk.util.Parallel.*;
* </code>
* </pre>
*
* A similar method facilitates tasks that reduce a sequence of indexed values
* to one or more values. For example, given the following method:
*
* <pre>
* <code>
* static float sum(float[] a) {
* int n = a.length;
* float s = 0.0f;
* for (int i=0; i<n; ++i)
* s += a[i];
* return s;
* }
* </code>
* </pre>
*
* serial and parallel versions for 2D arrays may be written as:
*
* <pre>
* <code>
* static float sumSerial(float[][] a) {
* int n = a.length;
* float s = 0.0f;
* for (int i=0; i<n; ++i)
* s += sum(a[i]);
* return s;
* }
* </code>
* </pre>
*
* and
*
* <pre>
* <code>
* static float sumParallel(final float[][] a) {
* int n = a.length;
* return Parallel.reduce(n,new Parallel.ReduceInt<Float>() {
* public Float compute(int i) {
* return sum(a[i]);
* }
* public Float combine(Float s1, Float s2) {
* return s1+s2;
* }
* });
* }
* </code>
* </pre>
*
* In the parallel version, we implement the interface {@code ReduceInt} with
* two methods, one to {@code compute} sums of array elements and another to
* {@code combine} two such sums together. The same pattern works for other
* reduce operations. For example, with similar functions we could compute
* minimum and maximum values (in a single reduce) for any indexed sequence of
* values.
* <p>
* More general loops are supported, and are equivalent to the following serial
* code:
*
* <pre>
* <code>
* for (int i=begin; i<end; i+=step)
* // some computation that depends on i
* </code>
* </pre>
*
* The methods loop and reduce require that begin is less than end and that step
* is positive. The requirement that begin is less than end ensures that reduce
* is always well-defined. The requirement that step is positive ensures that
* the loop terminates.
* <p>
* Static methods loop and reduce submit tasks to a fork-join framework that
* maintains a pool of threads shared by all users of these methods. These
* methods recursively split tasks so that disjoint sets of indices are
* processed in parallel by different threads.
* <p>
* In addition to the three loop parameters begin, end, and step, a fourth
* parameter chunk may be specified. This chunk parameter is a threshold for
* splitting tasks so that they can be performed in parallel. If a range of
* indices to be processed is smaller than the chunk size, or if too many tasks
* have already been queued for processing, then the indices are processed
* serially. Otherwise, the range is split into two parts for processing by new
* tasks. If specified, the chunk size is a lower bound; the number of indices
* processed serially will never be lower, but may be higher, than a specified
* chunk size. The default chunk size is one.
* <p>
* The default chunk size is often sufficient, because the test for an excess
* number of queued tasks prevents tasks from being split needlessly. This test
* is especially useful when parallel loops are nested, as when looping over
* elements of multi-dimensional arrays.
* <p>
* For example, an implementation of the method {@code sqrParallel} for 3D
* arrays could simply call the 2D version listed above. Tasks will naturally
* tend to be split for outer loops, but not inner loops, thereby reducing
* overhead, time spent splitting and queueing tasks.
* <p>
* Reference: A Java Fork/Join Framework, by Doug Lea, describes the framework
* used to implement this class. This framework will be part of JDK 7.
*
* @author Dave Hale, Colorado School of Mines
* @version 2010.11.23
*/
public class Parallel
{
/** A loop body that computes something for an int index. */
public interface LoopInt
{
/**
* Computes for the specified loop index.
*
* @param i
* loop index.
*/
public void compute(int i);
}
/** A loop body that computes and returns a value for an int index. */
public interface ReduceInt<V>
{
/**
* Returns a value computed for the specified loop index.
*
* @param i
* loop index.
* @return the computed value.
*/
public V compute(int i);
/**
* Returns the combination of two specified values.
*
* @param v1
* a value.
* @param v2
* a value.
* @return the combined value.
*/
public V combine(V v1, V v2);
}
/**
* A wrapper for objects that are not thread-safe. Such objects have methods
* that cannot safely be executed concurrently in multiple threads. To use
* an unsafe object within a parallel computation, first construct an
* instance of this wrapper. Then, within the compute method, get the unsafe
* object; if null, construct and set a new unsafe object in this wrapper,
* before using the unsafe object to perform the computation. This pattern
* ensures that each thread computes using a distinct unsafe object. For
* example,
*
* <pre>
* <code>
* final Parallel.Unsafe<Worker> nts = new Parallel.Unsafe<Worker>();
* Parallel.loop(count,new Parallel.LoopInt() {
* public void compute(int i) {
* Worker w = nts.get(); // get worker for the current thread
* if (w==null) nts.set(w=new Worker()); // if null, make one
* w.work(); // the method work need not be thread-safe
* }
* });
* </code>
* </pre>
*
* This wrapper is most useful when (1) the cost of constructing an unsafe
* object is high, relative to the cost of each call to compute, and (2) the
* number of threads calling compute is significantly lower than the total
* number of such calls. Otherwise, if either of these conditions is false,
* then simply construct a new unsafe object within the compute method.
* <p>
* This wrapper works much like the Java standard class ThreadLocal, except
* that an object within this wrapper can be garbage-collected before its
* thread dies. This difference is important because fork-join worker
* threads are pooled and will typically die only when a program ends.
*/
public static class Unsafe<T>
{
/**
* Constructs a wrapper for objects that are not thread-safe.
*/
public Unsafe()
{
int initialCapacity = 16; // the default initial capacity
float loadFactor = 0.5f; // huge numbers of threads are unlikely
int concurrencyLevel = 2 * _pool.getParallelism();
_map = new ConcurrentHashMap<Thread, T>(initialCapacity,
loadFactor, concurrencyLevel);
}
/**
* Gets the object in this wrapper for the current thread.
*
* @return the object; null, if not yet set for the current thread.
*/
public T get()
{
return _map.get(Thread.currentThread());
}
/**
* Sets the object in this wrapper for the current thread.
*
* @param object
* the object.
*/
public void set(T object)
{
_map.put(Thread.currentThread(), object);
}
/**
* Returns a collection of all unsafe objects in this wrapper. This
* method is useful only after parallel loops have ended.
*
* @return the collection of unsafe objects.
*/
public Collection<T> getAll()
{
return _map.values();
}
private final ConcurrentHashMap<Thread, T> _map;
}
/**
* Performs a loop <code>for (int i=0; i<end; ++i)</code>.
*
* @param end
* the end index (not included) for the loop.
* @param body
* the loop body.
*/
public static void loop(int end, LoopInt body)
{
loop(0, end, 1, 1, body);
}
/**
* Performs a loop <code>for (int i=begin; i<end; ++i)</code>.
*
* @param begin
* the begin index for the loop; must be less than end.
* @param end
* the end index (not included) for the loop.
* @param body
* the loop body.
*/
public static void loop(int begin, int end, LoopInt body)
{
loop(begin, end, 1, 1, body);
}
/**
* Performs a loop <code>for (int i=begin; i<end; i+=step)</code>.
*
* @param begin
* the begin index for the loop; must be less than end.
* @param end
* the end index (not included) for the loop.
* @param step
* the index increment; must be positive.
* @param body
* the loop body.
*/
public static void loop(int begin, int end, int step, LoopInt body)
{
loop(begin, end, step, 1, body);
}
/**
* Performs a loop <code>for (int i=begin; i<end; i+=step)</code>.
*
* @param begin
* the begin index for the loop; must be less than end.
* @param end
* the end index (not included) for the loop.
* @param step
* the index increment; must be positive.
* @param chunk
* the chunk size; must be positive.
* @param body
* the loop body.
*/
public static void loop(int begin, int end, int step, int chunk,
LoopInt body)
{
checkArgs(begin, end, step, chunk);
if (_serial || end <= begin + chunk * step) {
for (int i = begin; i < end; i += step) {
body.compute(i);
}
}
else {
LoopIntAction task = new LoopIntAction(begin, end, step, chunk,
body);
if (LoopIntAction.inForkJoinPool()) {
task.invoke();
}
else {
_pool.invoke(task);
}
}
}
/**
* Performs a reduce <code>for (int i=0; i<end; ++i)</code>.
*
* @param end
* the end index (not included) for the loop.
* @param body
* the loop body.
* @return the computed value.
*/
public static <V> V reduce(int end, ReduceInt<V> body)
{
return reduce(0, end, 1, 1, body);
}
/**
* Performs a reduce <code>for (int i=begin; i<end; ++i)</code>.
*
* @param begin
* the begin index for the loop; must be less than end.
* @param end
* the end index (not included) for the loop.
* @param body
* the loop body.
* @return the computed value.
*/
public static <V> V reduce(int begin, int end, ReduceInt<V> body)
{
return reduce(begin, end, 1, 1, body);
}
/**
* Performs a reduce <code>for (int i=begin; i<end; i+=step)</code>.
*
* @param begin
* the begin index for the loop; must be less than end.
* @param end
* the end index (not included) for the loop.
* @param step
* the index increment; must be positive.
* @param body
* the loop body.
* @return the computed value.
*/
public static <V> V reduce(int begin, int end, int step, ReduceInt<V> body)
{
return reduce(begin, end, step, 1, body);
}
/**
* Performs a reduce <code>for (int i=begin; i<end; i+=step)</code>.
*
* @param begin
* the begin index for the loop; must be less than end.
* @param end
* the end index (not included) for the loop.
* @param step
* the index increment; must be positive.
* @param chunk
* the chunk size; must be positive.
* @param body
* the loop body.
* @return the computed value.
*/
public static <V> V reduce(int begin, int end, int step, int chunk,
ReduceInt<V> body)
{
checkArgs(begin, end, step, chunk);
if (_serial || end <= begin + chunk * step) {
V v = body.compute(begin);
for (int i = begin + step; i < end; i += step) {
V vi = body.compute(i);
v = body.combine(v, vi);
}
return v;
}
else {
ReduceIntTask<V> task = new ReduceIntTask<V>(begin, end, step,
chunk, body);
if (ReduceIntTask.inForkJoinPool()) {
return task.invoke();
}
else {
return _pool.invoke(task);
}
}
}
/**
* Enables or disables parallel processing by all methods of this class. By
* default, parallel processing is enabled. If disabled, all tasks will be
* executed on the current thread.
* <p>
* <em>Setting this flag to false disables parallel processing for all
* users of this class.</em> This method should therefore be used for
* testing and benchmarking only.
*
* @param parallel
* true, for parallel processing; false, otherwise.
*/
public static void setParallel(boolean parallel)
{
_serial = !parallel;
}
// /////////////////////////////////////////////////////////////////////////
// private
// Implementation notes:
// Each fork-join task below has a range of indices to be processed.
// If the range is less than or equal to the chunk size, or if the
// queue for the current thread holds too many tasks already, then
// simply process the range on the current thread. Otherwise, split
// the range into two parts that are approximately equal, ensuring
// that the left part is at least as large as the right part. If the
// right part is not empty, fork a new task. Then compute the left
// part in the current thread, and, if necessary, join the right part.
// Threshold for number of surplus queued tasks. Used below to
// determine whether or not to split a task into two subtasks.
private static final int NSQT = 6;
// The pool shared by all fork-join tasks created through this class.
private static ForkJoinPool _pool = new ForkJoinPool();
// Serial flag; true for no parallel processing.
private static boolean _serial = false;
/**
* Checks loop arguments.
*/
private static void checkArgs(int begin, int end, int step, int chunk)
{
argument(begin < end, "begin<end");
argument(step > 0, "step>0");
argument(chunk > 0, "chunk>0");
}
public static void argument(boolean condition, String message)
{
if (!condition)
throw new IllegalArgumentException("required condition: " + message);
}
/**
* Splits range [begin:end) into [begin:middle) and [middle:end). The
* returned middle index equals begin plus an integer multiple of step.
*/
private static int middle(int begin, int end, int step)
{
return begin + step + ((end - begin - 1) / 2) / step * step;
}
/**
* Fork-join task for parallel loop.
*/
private static class LoopIntAction
extends RecursiveAction
{
LoopIntAction(int begin, int end, int step, int chunk, LoopInt body)
{
assert begin < end : "begin < end";
_begin = begin;
_end = end;
_step = step;
_chunk = chunk;
_body = body;
}
@Override
protected void compute()
{
if (_end <= _begin + _chunk * _step
|| getSurplusQueuedTaskCount() > NSQT) {
for (int i = _begin; i < _end; i += _step) {
_body.compute(i);
}
}
else {
int middle = middle(_begin, _end, _step);
LoopIntAction l = new LoopIntAction(_begin, middle, _step,
_chunk, _body);
LoopIntAction r = (middle < _end) ? new LoopIntAction(middle,
_end, _step, _chunk, _body) : null;
if (r != null)
r.fork();
l.compute();
if (r != null)
r.join();
}
}
private final int _begin, _end, _step, _chunk;
private final LoopInt _body;
}
/**
* Fork-join task for parallel reduce.
*/
private static class ReduceIntTask<V>
extends RecursiveTask<V>
{
ReduceIntTask(int begin, int end, int step, int chunk, ReduceInt<V> body)
{
assert begin < end : "begin < end";
_begin = begin;
_end = end;
_step = step;
_chunk = chunk;
_body = body;
}
@Override
protected V compute()
{
if (_end <= _begin + _chunk * _step
|| getSurplusQueuedTaskCount() > NSQT) {
// range is small enough (or queues are busy): reduce this range serially
V v = _body.compute(_begin);
for (int i = _begin + _step; i < _end; i += _step) {
V vi = _body.compute(i);
v = _body.combine(v, vi);
}
return v;
}
else {
// split the range, fork the right part, compute the left part in this thread
int middle = middle(_begin, _end, _step);
ReduceIntTask<V> l = new ReduceIntTask<V>(_begin, middle, _step,
_chunk, _body);
ReduceIntTask<V> r = (middle < _end) ? new ReduceIntTask<V>(middle,
_end, _step, _chunk, _body) : null;
if (r != null)
r.fork();
V v = l.compute();
if (r != null)
v = _body.combine(v, r.join());
return v;
}
}
private final int _begin, _end, _step, _chunk;
private final ReduceInt<V> _body;
}
}
================================================
SYMBOL INDEX (189 symbols across 13 files)
================================================
FILE: src/LFTM.java
class LFTM (line 23) | public class LFTM
method main (line 25) | public static void main(String[] args)
method help (line 92) | public static void help(CmdLineParser parser)
FILE: src/eval/ClusteringEval.java
class ClusteringEval (line 26) | public class ClusteringEval
method ClusteringEval (line 37) | public ClusteringEval(String inPathGoldenLabelsFile, String inPathDocT...
method readGoldenLabelsFile (line 50) | public void readGoldenLabelsFile()
method readDocTopicProsFile (line 76) | public void readDocTopicProsFile()
method computePurity (line 125) | public double computePurity()
method computeNMIscore (line 145) | public double computeNMIscore()
method evaluate (line 179) | public static void evaluate(String pathGoldenLabelsFile,
method main (line 229) | public static void main(String[] args)
FILE: src/models/LFDMM.java
class LFDMM (line 34) | public class LFDMM
method LFDMM (line 100) | public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int in...
method LFDMM (line 109) | public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int in...
method LFDMM (line 118) | public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int in...
method LFDMM (line 127) | public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int in...
method LFDMM (line 136) | public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int in...
method readWordVectorsFile (line 239) | public void readWordVectorsFile(String pathToWordVectorsFile)
method initialize (line 280) | public void initialize()
method initialize (line 310) | public void initialize(String pathToTopicAssignmentFile)
method inference (line 356) | public void inference()
method optimizeTopicVectors (line 391) | public void optimizeTopicVectors()
method sampleSingleIteration (line 441) | public void sampleSingleIteration()
method sampleSingleInitialIteration (line 496) | public void sampleSingleInitialIteration()
method writeParameters (line 552) | public void writeParameters()
method writeDictionary (line 575) | public void writeDictionary()
method writeIDbasedCorpus (line 586) | public void writeIDbasedCorpus()
method writeTopicAssignments (line 601) | public void writeTopicAssignments()
method writeTopicVectors (line 616) | public void writeTopicVectors()
method writeTopTopicalWords (line 629) | public void writeTopTopicalWords()
method writeTopicWordPros (line 665) | public void writeTopicWordPros()
method writeDocTopicPros (line 680) | public void writeDocTopicPros()
method write (line 708) | public void write()
method main (line 717) | public static void main(String args[])
FILE: src/models/LFDMM_Inf.java
class LFDMM_Inf (line 36) | public class LFDMM_Inf
method LFDMM_Inf (line 103) | public LFDMM_Inf(String pathToTrainingParasFile, String pathToUnseenCo...
method parseTrainingParasFile (line 208) | private HashMap<String, String> parseTrainingParasFile(
method initializeWordCount (line 231) | private void initializeWordCount(String pathToTrainingCorpus,
method readWordVectorsFile (line 298) | public void readWordVectorsFile(String pathToWordVectorsFile)
method initialize (line 341) | public void initialize()
method inference (line 371) | public void inference()
method optimizeTopicVectors (line 408) | public void optimizeTopicVectors()
method sampleSingleIteration (line 461) | public void sampleSingleIteration()
method sampleSingleInitialIteration (line 518) | public void sampleSingleInitialIteration()
method writeParameters (line 575) | public void writeParameters()
method writeDictionary (line 597) | public void writeDictionary()
method writeIDbasedCorpus (line 608) | public void writeIDbasedCorpus()
method writeTopicAssignments (line 623) | public void writeTopicAssignments()
method writeTopicVectors (line 638) | public void writeTopicVectors()
method writeTopTopicalWords (line 651) | public void writeTopTopicalWords()
method writeTopicWordPros (line 688) | public void writeTopicWordPros()
method writeDocTopicPros (line 706) | public void writeDocTopicPros()
method write (line 736) | public void write()
method main (line 745) | public static void main(String args[])
FILE: src/models/LFLDA.java
class LFLDA (line 33) | public class LFLDA
method LFLDA (line 102) | public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int in...
method LFLDA (line 111) | public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int in...
method LFLDA (line 120) | public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int in...
method LFLDA (line 129) | public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int in...
method LFLDA (line 138) | public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int in...
method readWordVectorsFile (line 242) | public void readWordVectorsFile(String pathToWordVectorsFile)
method initialize (line 283) | public void initialize()
method initialize (line 314) | public void initialize(String pathToTopicAssignmentFile)
method inference (line 363) | public void inference()
method optimizeTopicVectors (line 398) | public void optimizeTopicVectors()
method sampleSingleIteration (line 448) | public void sampleSingleIteration()
method sampleSingleInitialIteration (line 497) | public void sampleSingleInitialIteration()
method writeParameters (line 547) | public void writeParameters()
method writeDictionary (line 570) | public void writeDictionary()
method writeIDbasedCorpus (line 581) | public void writeIDbasedCorpus()
method writeTopicAssignments (line 596) | public void writeTopicAssignments()
method writeTopicVectors (line 611) | public void writeTopicVectors()
method writeTopTopicalWords (line 624) | public void writeTopTopicalWords()
method writeTopicWordPros (line 660) | public void writeTopicWordPros()
method writeDocTopicPros (line 675) | public void writeDocTopicPros()
method write (line 690) | public void write()
method main (line 699) | public static void main(String args[])
FILE: src/models/LFLDA_Inf.java
class LFLDA_Inf (line 35) | public class LFLDA_Inf
method LFLDA_Inf (line 105) | public LFLDA_Inf(String pathToTrainingParasFile, String pathToUnseenCo...
method parseTrainingParasFile (line 211) | private HashMap<String, String> parseTrainingParasFile(
method initializeWordCount (line 234) | private void initializeWordCount(String pathToTrainingCorpus,
method readWordVectorsFile (line 302) | public void readWordVectorsFile(String pathToWordVectorsFile)
method initialize (line 345) | public void initialize()
method inference (line 377) | public void inference()
method optimizeTopicVectors (line 414) | public void optimizeTopicVectors()
method sampleSingleIteration (line 467) | public void sampleSingleIteration()
method sampleSingleInitialIteration (line 518) | public void sampleSingleInitialIteration()
method writeParameters (line 569) | public void writeParameters()
method writeDictionary (line 591) | public void writeDictionary()
method writeIDbasedCorpus (line 602) | public void writeIDbasedCorpus()
method writeTopicAssignments (line 617) | public void writeTopicAssignments()
method writeTopicVectors (line 632) | public void writeTopicVectors()
method writeTopTopicalWords (line 645) | public void writeTopTopicalWords()
method writeTopicWordPros (line 682) | public void writeTopicWordPros()
method writeDocTopicPros (line 700) | public void writeDocTopicPros()
method write (line 717) | public void write()
method main (line 726) | public static void main(String args[])
FILE: src/models/TopicVectorOptimizer.java
class TopicVectorOptimizer (line 17) | public class TopicVectorOptimizer
method TopicVectorOptimizer (line 37) | public TopicVectorOptimizer(double[] inTopicVector, int[] inWordCount,
method getNumParameters (line 71) | @Override
method getParameters (line 77) | @Override
method getParameter (line 84) | @Override
method setParameters (line 90) | @Override
method setParameter (line 97) | @Override
method getValueGradient (line 103) | @Override
method getValue (line 123) | @Override
method computePartitionFunction (line 142) | public double computePartitionFunction(double[] elements1,
FILE: src/utility/CmdArgs.java
class CmdArgs (line 5) | public class CmdArgs
FILE: src/utility/FuncUtils.java
class FuncUtils (line 10) | public class FuncUtils
method sortByValueDescending (line 12) | public static <K, V extends Comparable<? super V>> Map<K, V> sortByVal...
method sortByValueAscending (line 32) | public static <K, V extends Comparable<? super V>> Map<K, V> sortByVal...
method nextDiscrete (line 58) | public static int nextDiscrete(double[] probs)
method mean (line 75) | public static double mean(double[] m)
method stddev (line 83) | public static double stddev(double[] m)
FILE: src/utility/LBFGS.java
class LBFGS (line 31) | public class LBFGS
method LBFGS (line 56) | public LBFGS(Optimizable.ByGradientValue function, double inTolerance)
method getOptimizable (line 63) | @Override
method isConverged (line 69) | @Override
method setLineOptimizer (line 81) | public void setLineOptimizer(LineOptimizer.ByGradient lineOpt)
method setTolerance (line 102) | public void setTolerance(double newtol)
method setEvaluator (line 107) | public void setEvaluator(OptimizerEvaluator.ByGradient eval)
method getIteration (line 112) | public int getIteration()
method optimize (line 117) | @Override
method optimize (line 123) | @Override
method reset (line 334) | public void reset()
method push (line 347) | private void push(LinkedList l, double[] toadd)
method push (line 379) | private void push(LinkedList l, double toadd)
FILE: src/utility/MTRandom.java
class MTRandom (line 3) | public class MTRandom
method setSeed (line 8) | public static void setSeed(long seed)
method nextDouble (line 13) | public static double nextDouble()
method nextInt (line 18) | public static int nextInt(int n)
method nextBoolean (line 23) | public static boolean nextBoolean()
FILE: src/utility/MersenneTwister.java
class MersenneTwister (line 175) | public strictfp class MersenneTwister
method clone (line 210) | public Object clone()
method stateEquals (line 223) | public boolean stateEquals(Object o)
method readState (line 242) | public void readState(DataInputStream stream)
method writeState (line 259) | public void writeState(DataOutputStream stream)
method MersenneTwister (line 278) | public MersenneTwister()
method MersenneTwister (line 287) | public MersenneTwister(long seed)
method MersenneTwister (line 298) | public MersenneTwister(int[] array)
method setSeed (line 309) | synchronized public void setSeed(long seed)
method setSeed (line 343) | synchronized public void setSeed(int[] array)
method next (line 382) | synchronized protected int next(int bits)
method writeObject (line 419) | private synchronized void writeObject(ObjectOutputStream out)
method readObject (line 426) | private synchronized void readObject(ObjectInputStream in)
method nextBoolean (line 437) | public boolean nextBoolean()
method nextBoolean (line 449) | public boolean nextBoolean(float probability)
method nextBoolean (line 465) | public boolean nextBoolean(double probability)
method nextInt (line 481) | public int nextInt(int n)
method nextLong (line 503) | public long nextLong(long n)
method nextDouble (line 520) | public double nextDouble()
method nextDouble (line 551) | public double nextDouble(boolean includeZero, boolean includeOne)
method nextFloat (line 568) | public float nextFloat()
method nextFloat (line 599) | public float nextFloat(boolean includeZero, boolean includeOne)
method nextBytes (line 617) | public void nextBytes(byte[] bytes)
method nextChar (line 625) | public char nextChar()
method nextShort (line 633) | public short nextShort()
method nextByte (line 640) | public byte nextByte()
method nextGaussian (line 652) | synchronized public double nextGaussian()
method main (line 676) | public static void main(String args[])
FILE: src/utility/Parallel.java
class Parallel (line 172) | public class Parallel
type LoopInt (line 176) | public interface LoopInt
method compute (line 185) | public void compute(int i);
type ReduceInt (line 189) | public interface ReduceInt<V>
method compute (line 199) | public V compute(int i);
method combine (line 210) | public V combine(V v1, V v2);
class Unsafe (line 247) | public static class Unsafe<T>
method Unsafe (line 253) | public Unsafe()
method get (line 267) | public T get()
method set (line 278) | public void set(T object)
method getAll (line 289) | public Collection<T> getAll()
method loop (line 305) | public static void loop(int end, LoopInt body)
method loop (line 320) | public static void loop(int begin, int end, LoopInt body)
method loop (line 337) | public static void loop(int begin, int end, int step, LoopInt body)
method loop (line 356) | public static void loop(int begin, int end, int step, int chunk,
method reduce (line 386) | public static <V> V reduce(int end, ReduceInt<V> body)
method reduce (line 402) | public static <V> V reduce(int begin, int end, ReduceInt<V> body)
method reduce (line 420) | public static <V> V reduce(int begin, int end, int step, ReduceInt<V> ...
method reduce (line 440) | public static <V> V reduce(int begin, int end, int step, int chunk,
method setParallel (line 476) | public static void setParallel(boolean parallel)
method checkArgs (line 507) | private static void checkArgs(int begin, int end, int step, int chunk)
method argument (line 514) | public static void argument(boolean condition, String message)
method middle (line 524) | private static int middle(int begin, int end, int step)
class LoopIntAction (line 532) | private static class LoopIntAction
method LoopIntAction (line 535) | LoopIntAction(int begin, int end, int step, int chunk, LoopInt body)
method compute (line 545) | @Override
class ReduceIntTask (line 575) | private static class ReduceIntTask<V>
method ReduceIntTask (line 578) | ReduceIntTask(int begin, int end, int step, int chunk, ReduceInt<V> ...
method compute (line 588) | @Override
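The symbol index lists FuncUtils.nextDiscrete(double[] probs) (src/utility/FuncUtils.java, line 58), which the sampleSingleIteration methods of the model classes presumably call to draw a topic index from a (possibly unnormalized) discrete distribution. A minimal sketch of what a function with this signature typically does, assuming cumulative-mass (roulette-wheel) sampling; this is an illustration, not the repository's verbatim implementation:

    // Illustrative only: draws index i with probability probs[i] / sum(probs).
    // MTRandom.nextDouble() is the repository's own static RNG wrapper.
    public static int nextDiscrete(double[] probs)
    {
        double sum = 0.0;
        for (double p : probs)
            sum += p; // total (possibly unnormalized) mass
        double r = MTRandom.nextDouble() * sum; // uniform draw in [0, sum)
        double acc = 0.0;
        for (int i = 0; i < probs.length; i++) {
            acc += probs[i];
            if (r < acc)
                return i; // first index whose cumulative mass exceeds r
        }
        return probs.length - 1; // guard against floating-point round-off
    }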