Repository: datquocnguyen/LFTM
Branch: master
Commit: 2d074e52c6b9
Files: 23
Total size: 12.7 MB

Directory structure:
gitextract_qkmmkaa4/
├── License.txt
├── README.md
├── build.xml
├── jar/
│   └── LFTM.jar
├── lib/
│   ├── args4j-2.0.6.jar
│   └── mallet.jar
├── src/
│   ├── LFTM.java
│   ├── eval/
│   │   └── ClusteringEval.java
│   ├── models/
│   │   ├── LFDMM.java
│   │   ├── LFDMM_Inf.java
│   │   ├── LFLDA.java
│   │   ├── LFLDA_Inf.java
│   │   └── TopicVectorOptimizer.java
│   └── utility/
│       ├── CmdArgs.java
│       ├── FuncUtils.java
│       ├── LBFGS.java
│       ├── MTRandom.java
│       ├── MersenneTwister.java
│       └── Parallel.java
└── test/
    ├── corpus.LABEL
    ├── corpus.txt
    ├── corpus_test.txt
    └── wordVectors.txt

================================================
FILE CONTENTS
================================================

================================================
FILE: License.txt
================================================
Implementations of the LF-LDA and LF-DMM latent feature topic models
Copyright (C) 2015-2016 by Dat Quoc Nguyen
dat.nguyen@students.mq.edu.au
Department of Computing, Macquarie University, Australia

This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with this program.
If not, see

================================================
FILE: README.md
================================================
# LF-LDA and LF-DMM latent feature topic models

This repository contains implementations of the LF-LDA and LF-DMM latent feature topic models, as described in my TACL paper:

Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. [Improving Topic Models with Latent Feature Word Representations](https://tacl2013.cs.columbia.edu/ojs/index.php/tacl/article/view/582/158). Transactions of the Association for Computational Linguistics, vol. 3, pp. 299-313, 2015. [[.bib]](http://web.science.mq.edu.au/~dqnguyen/papers/TACL.bib) [[Datasets]](http://web.science.mq.edu.au/~dqnguyen/papers/TACL-datasets.zip) [[Example_20Newsgroups_20Topics_Top50Words]](http://web.science.mq.edu.au/~dqnguyen/papers/TACL_TopWords_N20_20Topics.zip)

Implementations of the LDA and DMM topic models are available at [http://jldadmm.sourceforge.net/](http://jldadmm.sourceforge.net/).

## Usage

This section describes how to use the implementations from the command line or terminal with the pre-compiled `LFTM.jar` file. It is assumed that Java 1.7+ is installed and can be run from the command line or terminal (for example, by adding Java to the `path` environment variable on Windows). The pre-compiled `LFTM.jar` file and the source code are in the `jar` and `src` folders, respectively. Users can recompile the source code by running `ant` (this assumes `ant` is already installed). Example input files are in the `test` folder.

#### File format of input topic-modeling corpus

As in the `corpus.txt` file in the `test` folder, each line in the input topic-modeling corpus represents one document. A document is a sequence of words/tokens separated by whitespace characters.
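To make the corpus format concrete, here is a minimal Java sketch that turns such a file into lists of integer word IDs. It follows the same steps as the `LFDMM` constructor in `src/models/LFDMM.java` (trim each line, skip empty lines, split on `\\s+`, give each new word type the next ID), but it is a standalone illustration: the class `CorpusReaderSketch` and its `fromText` helper are hypothetical and not part of LFTM.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper: reads a corpus with one document per line into word IDs.
final class CorpusReaderSketch {
    // Word type -> ID, assigned in order of first occurrence (0, 1, 2, ...).
    final Map<String, Integer> word2Id = new HashMap<String, Integer>();
    // One list of word IDs per non-empty input line.
    final List<List<Integer>> corpus = new ArrayList<List<Integer>>();

    void read(BufferedReader br) throws IOException {
        for (String line; (line = br.readLine()) != null;) {
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue; // blank lines do not count as documents
            }
            List<Integer> document = new ArrayList<Integer>();
            for (String word : trimmed.split("\\s+")) {
                Integer id = word2Id.get(word);
                if (id == null) {
                    id = word2Id.size(); // next unused ID
                    word2Id.put(word, id);
                }
                document.add(id);
            }
            corpus.add(document);
        }
    }

    // Convenience wrapper for in-memory text; StringReader does no real I/O.
    static CorpusReaderSketch fromText(String text) {
        CorpusReaderSketch reader = new CorpusReaderSketch();
        try {
            reader.read(new BufferedReader(new StringReader(text)));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return reader;
    }

    public static void main(String[] args) {
        CorpusReaderSketch reader = fromText("topic model words\n\nword vectors topic\n");
        System.out.println(reader.corpus);         // [[0, 1, 2], [3, 4, 0]]
        System.out.println(reader.word2Id.size()); // 5
    }
}
```

Note that the empty second line is skipped rather than treated as an empty document, which matches the `doc.trim().length() == 0` check in `LFDMM`.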
Users should preprocess the input topic-modeling corpus before training the topic models, for example: lowercasing, removing non-alphabetic characters and stop-words, and removing words that are shorter than 3 characters or that occur fewer than a certain number of times.

#### Format of input word-vector file

As in the `wordVectors.txt` file in the `test` folder, each line in the input word-vector file starts with a word type, followed by its vector representation. The word vectors can either be pre-trained vectors learned from large external corpora OR vectors trained on the input topic-modeling corpus itself. When using pre-trained word vectors learned from large external corpora, users must remove from the input topic-modeling corpus every word that is not found in the input word-vector file (`LFDMM.readWordVectorsFile`, for example, reports any word without a corresponding vector and throws an exception).

Some sets of pre-trained word vectors can be found at:

[Word2Vec: https://code.google.com/p/word2vec/](https://code.google.com/p/word2vec/)

[Glove: http://nlp.stanford.edu/projects/glove/](http://nlp.stanford.edu/projects/glove/)

If the input topic-modeling corpus is highly domain-specific, the domain of the external corpus (from which the word vectors are derived) should not be too different from that of the input topic-modeling corpus. For example, in the biomedical domain, users may use Word2Vec or Glove to learn 50- or 100-dimensional word vectors on the large external MEDLINE corpus instead of using the pre-trained Word2Vec or Glove word vectors.

### Training LF-LDA and LF-DMM

`$ java [-Xmx2G] -jar jar/LFTM.jar -model -corpus -vectors [-ntopics ] [-alpha ] [-beta ] [-lambda ] [-initers ] [-niters ] [-twords ] [-name ] [-sstep ]`

where hyper-parameters in [ ] are optional.

* `-model`: Specify the topic model (`LFLDA` or `LFDMM`).
* `-corpus`: Specify the path to the input training corpus file.
* `-vectors`: Specify the path to the file containing word vectors.
* `-ntopics `: Specify the number of topics. The default value is 20.
* `-alpha `: Specify the hyper-parameter alpha. Following [1, 2], the default value is 0.1.
* `-beta `: Specify the hyper-parameter beta. The default value is 0.01. Following [2], you might also want to try a beta value of 0.1 for short texts.
* `-lambda `: Specify the mixture weight lambda (0.0 < lambda <= 1.0). Set lambda to 1.0 to obtain the best topic coherence. For document clustering/classification evaluation, fine-tune this parameter to obtain the highest scores if you have time; otherwise, try both 0.6 and 1.0 (if you can only try one value, I suggest 0.6 for normal text corpora and 1.0 for short text corpora).
* `-initers `: Specify the number of initial sampling iterations used to separate the counts for the latent feature component and the Dirichlet multinomial component. The default value is 2000.
* `-niters `: Specify the number of sampling iterations for the latent feature topic models. The default value is 200.
* `-twords `: Specify the number of most probable topical words. The default value is 20.
* `-name `: Specify a name for the topic modeling experiment. The default value is “model”.
* `-sstep `: Specify a step for saving the sampling output (`-sstep` value < `-niters` value). The default value is 0 (i.e. only the output from the last sample is saved).

NOTE: topic vectors are learned in parallel, so running the LFTM code on a machine with multiple CPUs/cores gives a significantly faster training process, e.g. use a multi-core computer, or set the number of CPUs requested for a remote job equal to the number of topics.
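The `-lambda` weight enters the LF-DMM sampler (`sampleSingleIteration` in `src/models/LFDMM.java`) as a per-word mixture: lambda times the softmax probability of the word given the topic vector, plus (1 - lambda) times the smoothed Dirichlet multinomial estimate (n_tw + beta) / (n_t + V * beta). The sketch below restates that mixture for a single topic so the effect of lambda is easy to see. `MixtureSketch` and its methods are hypothetical names; unlike LFDMM, it recomputes the softmax on every call instead of caching `expDotProductValues` and `sumExpValues` per topic. Subtracting the maximum score before exponentiating mirrors the fallback LFDMM uses when its partition function comes out as 0 or infinite.

```java
// Hypothetical standalone sketch of the LF-DMM per-word mixture probability.
final class MixtureSketch {
    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) {
            s += a[i] * b[i];
        }
        return s;
    }

    // Softmax over the vocabulary, shifted by the maximum score so exp() cannot overflow.
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] probs = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < scores.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    // lambda * softmax(topicVector . wordVectors)[w]
    //   + (1 - lambda) * (countDMM[w] + beta) / (sumCountDMM + V * beta)
    static double wordProb(int w, double[] topicVector, double[][] wordVectors,
                           int[] countDMM, int sumCountDMM, double beta, double lambda) {
        int vocabularySize = wordVectors.length;
        double[] scores = new double[vocabularySize];
        for (int v = 0; v < vocabularySize; v++) {
            scores[v] = dot(topicVector, wordVectors[v]);
        }
        double latentFeature = softmax(scores)[w];
        double dirichletMultinomial = (countDMM[w] + beta) / (sumCountDMM + vocabularySize * beta);
        return lambda * latentFeature + (1 - lambda) * dirichletMultinomial;
    }

    public static void main(String[] args) {
        double[][] wordVectors = {{1.0, 0.0}, {0.0, 1.0}, {0.5, 0.5}};
        double[] topicVector = {2.0, 0.0};
        int[] countDMM = {3, 1, 0}; // these counts sum to 4
        double total = 0.0;
        for (int w = 0; w < wordVectors.length; w++) {
            total += wordProb(w, topicVector, wordVectors, countDMM, 4, 0.01, 0.6);
        }
        System.out.println(total); // 1.0 up to floating-point rounding
    }
}
```

Because both components are proper distributions over the vocabulary, the mixture sums to 1 for any lambda in [0, 1]; with lambda = 1.0 the Dirichlet multinomial term drops out entirely.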
Examples:

`$ java -jar jar/LFTM.jar -model LFLDA -corpus test/corpus.txt -vectors test/wordVectors.txt -ntopics 4 -alpha 0.1 -beta 0.01 -lambda 0.6 -initers 500 -niters 50 -name testLFLDA`

This command runs 500 `LDA` sampling iterations (i.e., `-initers 500`) for initialization and then 50 `LF-LDA` sampling iterations (i.e., `-niters 50`). The output files are saved in the same folder as the input training corpus file, in this case the `test` folder: `testLFLDA.theta`, `testLFLDA.phi`, `testLFLDA.topWords`, `testLFLDA.topicAssignments` and `testLFLDA.paras`, containing the document-to-topic distributions, topic-to-word distributions, top topical words, topic assignments and model hyper-parameters, respectively.

Similarly, we run:

`$ java -jar jar/LFTM.jar -model LFDMM -corpus test/corpus.txt -vectors test/wordVectors.txt -ntopics 4 -alpha 0.1 -beta 0.1 -lambda 1.0 -initers 500 -niters 50 -name testLFDMM`

This produces the output files `testLFDMM.theta`, `testLFDMM.phi`, `testLFDMM.topWords`, `testLFDMM.topicAssignments` and `testLFDMM.paras`.

In the LF-LDA and LF-DMM latent feature topic models, each word is generated either by the latent feature topic-to-word component OR by the topic-to-word Dirichlet multinomial component. In the implementation, instead of using a binary selection variable to record this choice, I simply add the number of topics to the topic assignment value of a word generated by the Dirichlet multinomial component. For example, with 20 topics, the output topic assignment `3 23 4 4 24 3 23 3 23 3 23` for a document means that the first word in the document is generated from topic 3 by the latent feature topic-to-word component. The second word is also generated from topic `23 - 20 = 3`, but by the topic-to-word Dirichlet multinomial component. The remaining words in the document are interpreted in the same way.
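The topic-assignment encoding described above can be decoded with a remainder and a comparison. The helper below is a hypothetical sketch (not part of LFTM) that parses one line of a `.topicAssignments` file, assuming whitespace-separated integers and the same `-ntopics` value used for training; it relies on the same `% numTopics` step that `LFDMM` applies when it reads a topic-assignment file.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper for reading lines of a .topicAssignments file.
final class TopicAssignmentDecoder {
    static final class Assignment {
        final int topic;                        // topic index in [0, numTopics)
        final boolean fromDirichletMultinomial; // true if the stored value was >= numTopics

        Assignment(int topic, boolean fromDirichletMultinomial) {
            this.topic = topic;
            this.fromDirichletMultinomial = fromDirichletMultinomial;
        }

        @Override
        public String toString() {
            return topic + (fromDirichletMultinomial ? "/DMM" : "/LF");
        }
    }

    // Values >= numTopics encode the Dirichlet multinomial component (topic + numTopics).
    static List<Assignment> decode(String line, int numTopics) {
        List<Assignment> assignments = new ArrayList<Assignment>();
        for (String token : line.trim().split("\\s+")) {
            int value = Integer.parseInt(token);
            assignments.add(new Assignment(value % numTopics, value >= numTopics));
        }
        return assignments;
    }

    public static void main(String[] args) {
        System.out.println(decode("3 23 4 4 24 3 23 3 23 3 23", 20));
        // [3/LF, 3/DMM, 4/LF, 4/LF, 4/DMM, 3/LF, 3/DMM, 3/LF, 3/DMM, 3/LF, 3/DMM]
    }
}
```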
### Document clustering evaluation

Here, we treat each topic as a cluster, and we assign each document to the topic with the highest probability given the document. To compute the Purity and normalized mutual information (NMI) clustering scores, we run:

`$ java -jar jar/LFTM.jar -model Eval -label -dir -prob `

* `-label`: Specify the path to the ground-truth label file. Each line in this label file contains the gold label of the corresponding document in the input training corpus. See the `corpus.LABEL` and `corpus.txt` files in the `test` folder.
* `-dir`: Specify the path to the directory containing the document-to-topic distribution files.
* `-prob`: Specify a document-to-topic distribution file, or a suffix shared by a group of document-to-topic distribution files, in the specified directory. Every file in that directory whose name ends with this value is evaluated.

Examples:

The command `$ java -jar jar/LFTM.jar -model Eval -label test/corpus.LABEL -dir test -prob testLFLDA.theta` produces the clustering scores for the `testLFLDA.theta` file.

The command `$ java -jar jar/LFTM.jar -model Eval -label test/corpus.LABEL -dir test -prob testLFDMM.theta` produces the clustering scores for the `testLFDMM.theta` file.

The command `$ java -jar jar/LFTM.jar -model Eval -label test/corpus.LABEL -dir test -prob theta` produces the clustering scores for all document-to-topic distribution files whose names end with `theta`; in this case, `testLFLDA.theta` and `testLFDMM.theta`. It also reports the mean and standard deviation of the clustering scores. The scores are printed and also written to the `-dir` directory in a file named after the `-prob` value plus `.PurityNMI` (here, `test/theta.PurityNMI`).

### Inference of topic distribution on unseen corpus

To infer topics on an unseen/new corpus using a pre-trained LF-LDA/LF-DMM topic model, we run:

`$ java -jar jar/LFTM.jar -model -paras -corpus [-initers ] [-niters ] [-twords ] [-name ] [-sstep ]`

* `-model`: Specify `LFLDAinf` or `LFDMMinf`.
* `-paras`: Specify the path to the hyper-parameter file produced by the pre-trained LF-LDA/LF-DMM topic model.
Examples:

`$ java -jar jar/LFTM.jar -model LFLDAinf -paras test/testLFLDA.paras -corpus test/corpus_test.txt -initers 500 -niters 50 -name testLFLDAinf`

`$ java -jar jar/LFTM.jar -model LFDMMinf -paras test/testLFDMM.paras -corpus test/corpus_test.txt -initers 500 -niters 50 -name testLFDMMinf`

## Acknowledgments

The LF-LDA and LF-DMM implementations use utilities including the LBFGS implementation from the [MALLET toolkit](http://mallet.cs.umass.edu/), the random number generator in the [Java version of MersenneTwister](http://cs.gmu.edu/~sean/research/), the `Parallel.java` utility from the [Mines Java Toolkit](http://dhale.github.io/jtk/api/edu/mines/jtk/util/Parallel.html) and the [Java command line arguments parser](http://args4j.kohsuke.org/sample.html). I would like to thank the authors of these utilities for sharing their code.

## References

[1] Yue Lu, Qiaozhu Mei, and ChengXiang Zhai. 2011. Investigating task performance of probabilistic topic models: an empirical study of PLSA and LDA. Information Retrieval, 14:178–203.

[2] Jianhua Yin and Jianyong Wang. 2014. A Dirichlet Multinomial Mixture Model-based Approach for Short Text Clustering. In Proceedings of the 20th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 233–242.

================================================
FILE: build.xml
================================================

================================================
FILE: src/LFTM.java
================================================
import models.LFDMM; import models.LFDMM_Inf; import models.LFLDA; import models.LFLDA_Inf; import org.kohsuke.args4j.CmdLineException; import org.kohsuke.args4j.CmdLineParser; import utility.CmdArgs; import eval.ClusteringEval; /** * Implementations of the LF-LDA and LF-DMM latent feature topic models, using * collapsed Gibbs sampling, as described in: * * Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. 2015.
* Improving Topic Models with Latent Feature Word Representations. Transactions * of the Association for Computational Linguistics, vol. 3, pp. 299-313. * * @author Dat Quoc Nguyen * */ public class LFTM { public static void main(String[] args) { CmdArgs cmdArgs = new CmdArgs(); CmdLineParser parser = new CmdLineParser(cmdArgs); try { parser.parseArgument(args); if (cmdArgs.model.equals("LFLDA")) { LFLDA lflda = new LFLDA(cmdArgs.corpus, cmdArgs.vectors, cmdArgs.ntopics, cmdArgs.alpha, cmdArgs.beta, cmdArgs.lambda, cmdArgs.initers, cmdArgs.niters, cmdArgs.twords, cmdArgs.expModelName, cmdArgs.initTopicAssgns, cmdArgs.savestep); lflda.inference(); } else if (cmdArgs.model.equals("LFDMM")) { LFDMM lfdmm = new LFDMM(cmdArgs.corpus, cmdArgs.vectors, cmdArgs.ntopics, cmdArgs.alpha, cmdArgs.beta, cmdArgs.lambda, cmdArgs.initers, cmdArgs.niters, cmdArgs.twords, cmdArgs.expModelName, cmdArgs.initTopicAssgns, cmdArgs.savestep); lfdmm.inference(); } else if (cmdArgs.model.equals("LFLDAinf")) { LFLDA_Inf lfldaInf = new LFLDA_Inf(cmdArgs.paras, cmdArgs.corpus, cmdArgs.initers, cmdArgs.niters, cmdArgs.twords, cmdArgs.expModelName, cmdArgs.savestep); lfldaInf.inference(); } else if (cmdArgs.model.equals("LFDMMinf")) { LFDMM_Inf lfdmmInf = new LFDMM_Inf(cmdArgs.paras, cmdArgs.corpus, cmdArgs.initers, cmdArgs.niters, cmdArgs.twords, cmdArgs.expModelName, cmdArgs.savestep); lfdmmInf.inference(); } else if (cmdArgs.model.equals("Eval")) { ClusteringEval.evaluate(cmdArgs.labelFile, cmdArgs.dir, cmdArgs.prob); } else { System.out .println("Error: Option \"-model\" must get \"LFLDA\" or \"LFDMM\" or \"LFLDAinf\" or \"LFDMMinf\" or \"Eval\""); System.out.println("\tLFLDA: Specify the LF-LDA topic model"); System.out.println("\tLFDMM: Specify the LF-DMM topic model"); System.out .println("\tLFLDAinf: Infer topics for unseen corpus using a pre-trained LF-LDA model"); System.out .println("\tLFDMMinf: Infer topics for unseen corpus using a pre-trained LF-DMM model"); System.out 
.println("\tEval: Specify the document clustering evaluation"); help(parser); return; } } catch (CmdLineException cle) { System.out.println("Error: " + cle.getMessage()); help(parser); return; } catch (Exception e) { System.out.println("Error: " + e.getMessage()); e.printStackTrace(); return; } } public static void help(CmdLineParser parser) { System.out.println("java -jar LFTM.jar [options ...] [arguments...]"); parser.printUsage(System.out); } }

================================================
FILE: src/eval/ClusteringEval.java
================================================
package eval; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.File; import java.io.FileReader; import java.io.FileWriter; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Set; import utility.FuncUtils; /** * Implementation of the Purity and NMI clustering evaluation scores, as described in Section 16.3 * in: * * Christopher D. Manning, Prabhakar Raghavan, and Hinrich Schütze. 2008. Introduction to * Information Retrieval. Cambridge University Press.
* * @author: Dat Quoc Nguyen */ public class ClusteringEval { String pathDocTopicProsFile; String pathGoldenLabelsFile; HashMap<String, Set<Integer>> goldenClusers; HashMap<String, Set<Integer>> outputClusers; int numDocs; public ClusteringEval(String inPathGoldenLabelsFile, String inPathDocTopicProsFile) throws Exception { pathDocTopicProsFile = inPathDocTopicProsFile; pathGoldenLabelsFile = inPathGoldenLabelsFile; goldenClusers = new HashMap<String, Set<Integer>>(); outputClusers = new HashMap<String, Set<Integer>>(); readGoldenLabelsFile(); readDocTopicProsFile(); } public void readGoldenLabelsFile() throws Exception { System.out.println("Reading golden labels file " + pathGoldenLabelsFile); int id = 0; BufferedReader br = null; try { br = new BufferedReader(new FileReader(pathGoldenLabelsFile)); for (String label; (label = br.readLine()) != null;) { label = label.trim(); Set<Integer> ids = new HashSet<Integer>(); if (goldenClusers.containsKey(label)) ids = goldenClusers.get(label); ids.add(id); goldenClusers.put(label, ids); id += 1; } } catch (Exception e) { e.printStackTrace(); } numDocs = id; } public void readDocTopicProsFile() throws Exception { System.out.println("Reading document-to-topic distribution file " + pathDocTopicProsFile); HashMap<Integer, String> docLabelOutput = new HashMap<Integer, String>(); int docIndex = 0; BufferedReader br = null; try { br = new BufferedReader(new FileReader(pathDocTopicProsFile)); for (String docTopicProbs; (docTopicProbs = br.readLine()) != null;) { String[] pros = docTopicProbs.trim().split("\\s+"); double maxPro = 0.0; int index = -1; for (int topicIndex = 0; topicIndex < pros.length; topicIndex++) { double pro = new Double(pros[topicIndex]); if (pro > maxPro) { maxPro = pro; index = topicIndex; } } docLabelOutput.put(docIndex, "Topic_" + new Integer(index).toString()); docIndex++; } } catch (Exception e) { e.printStackTrace(); } if (numDocs != docIndex) { System.out .println("Error: the number of documents is different from the number of labels!"); throw new Exception(); } for (Integer id : docLabelOutput.keySet()) { String label =
docLabelOutput.get(id); Set<Integer> ids = new HashSet<Integer>(); if (outputClusers.containsKey(label)) ids = outputClusers.get(label); ids.add(id); outputClusers.put(label, ids); } } public double computePurity() { int count = 0; for (String label : outputClusers.keySet()) { Set<Integer> docs = outputClusers.get(label); int correctAssignedDocNum = 0; for (String goldenLabel : goldenClusers.keySet()) { Set<Integer> goldenDocs = goldenClusers.get(goldenLabel); Set<Integer> outputDocs = new HashSet<Integer>(docs); outputDocs.retainAll(goldenDocs); if (outputDocs.size() >= correctAssignedDocNum) correctAssignedDocNum = outputDocs.size(); } count += correctAssignedDocNum; } double value = count * 1.0 / numDocs; System.out.println("\tPurity accuracy: " + value); return value; } public double computeNMIscore() { double MIscore = 0.0; for (String label : outputClusers.keySet()) { Set<Integer> docs = outputClusers.get(label); for (String goldenLabel : goldenClusers.keySet()) { Set<Integer> goldenDocs = goldenClusers.get(goldenLabel); Set<Integer> outputDocs = new HashSet<Integer>(docs); outputDocs.retainAll(goldenDocs); double numCorrectAssignedDocs = outputDocs.size() * 1.0; if (numCorrectAssignedDocs == 0.0) continue; MIscore += (numCorrectAssignedDocs / numDocs) * Math.log(numCorrectAssignedDocs * numDocs / (1.0 * docs.size() * goldenDocs.size())); } } double entropy = 0.0; for (String label : outputClusers.keySet()) { Set<Integer> docs = outputClusers.get(label); entropy += (-1.0 * docs.size() / numDocs) * Math.log(1.0 * docs.size() / numDocs); } for (String label : goldenClusers.keySet()) { Set<Integer> docs = goldenClusers.get(label); entropy += (-1.0 * docs.size() / numDocs) * Math.log(1.0 * docs.size() / numDocs); } double value = 2 * MIscore / entropy; System.out.println("\tNMI score: " + value); return value; } public static void evaluate(String pathGoldenLabelsFile, String pathToFolderOfDocTopicProsFiles, String suffix) throws Exception { BufferedWriter writer = new BufferedWriter(new FileWriter(pathToFolderOfDocTopicProsFiles + "/" + suffix + ".PurityNMI"));
writer.write("Golden-labels in: " + pathGoldenLabelsFile + "\n\n"); File[] files = new File(pathToFolderOfDocTopicProsFiles).listFiles(); List<Double> purity = new ArrayList<Double>(), nmi = new ArrayList<Double>(); for (File file : files) { if (!file.getName().endsWith(suffix)) continue; writer.write("Results for: " + file.getAbsolutePath() + "\n"); ClusteringEval dce = new ClusteringEval(pathGoldenLabelsFile, file.getAbsolutePath()); double value = dce.computePurity(); writer.write("\tPurity: " + value + "\n"); purity.add(value); value = dce.computeNMIscore(); writer.write("\tNMI: " + value + "\n"); nmi.add(value); } if (purity.size() == 0 || nmi.size() == 0) { System.out.println("Error: There is no file ending with " + suffix); throw new Exception(); } double[] purityValues = new double[purity.size()]; double[] nmiValues = new double[nmi.size()]; for (int i = 0; i < purity.size(); i++) purityValues[i] = purity.get(i).doubleValue(); for (int i = 0; i < nmi.size(); i++) nmiValues[i] = nmi.get(i).doubleValue(); writer.write("\n---\nMean purity: " + FuncUtils.mean(purityValues) + ", standard deviation: " + FuncUtils.stddev(purityValues)); writer.write("\nMean NMI: " + FuncUtils.mean(nmiValues) + ", standard deviation: " + FuncUtils.stddev(nmiValues)); System.out.println("---\nMean purity: " + FuncUtils.mean(purityValues) + ", standard deviation: " + FuncUtils.stddev(purityValues)); System.out.println("Mean NMI: " + FuncUtils.mean(nmiValues) + ", standard deviation: " + FuncUtils.stddev(nmiValues)); writer.close(); } public static void main(String[] args) throws Exception { ClusteringEval.evaluate("test/corpus.LABEL", "test", "theta"); } }

================================================
FILE: src/models/LFDMM.java
================================================
package models; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import
java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeMap; import utility.FuncUtils; import utility.LBFGS; import utility.Parallel; import cc.mallet.optimize.InvalidOptimizableException; import cc.mallet.optimize.Optimizer; import cc.mallet.types.MatrixOps; import cc.mallet.util.Randoms; /** * Implementation of the LF-DMM latent feature topic model, using collapsed Gibbs sampling, as * described in: * * Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. 2015. Improving Topic Models with * Latent Feature Word Representations. Transactions of the Association for Computational * Linguistics, vol. 3, pp. 299-313. * * @author Dat Quoc Nguyen */ public class LFDMM { public double alpha; // Hyper-parameter alpha public double beta; // Hyper-parameter beta // public double alphaSum; // alpha * numTopics public double betaSum; // beta * vocabularySize public int numTopics; // Number of topics public int topWords; // Number of most probable words for each topic public double lambda; // Mixture weight value public int numInitIterations; public int numIterations; // Number of EM-style sampling iterations public List<List<Integer>> corpus; // Word ID-based corpus public List<List<Integer>> topicAssignments; // Topic assignments for words // in the corpus public int numDocuments; // Number of documents in the corpus public int numWordsInCorpus; // Number of words in the corpus public HashMap<String, Integer> word2IdVocabulary; // Vocabulary to get ID // given a word public HashMap<Integer, String> id2WordVocabulary; // Vocabulary to get word // given an ID public int vocabularySize; // The number of word types in the corpus // Number of documents assigned to a topic public int[] docTopicCount; // numTopics * vocabularySize matrix // Given a topic: number of times a word type generated from the topic by // the Dirichlet multinomial component public int[][] topicWordCountDMM; // Total number of words generated from each topic by the Dirichlet // multinomial component public int[] sumTopicWordCountDMM;
// numTopics * vocabularySize matrix // Given a topic: number of times a word type generated from the topic by // the latent feature component public int[][] topicWordCountLF; // Total number of words generated from each topic by the latent feature // component public int[] sumTopicWordCountLF; // Double array used to sample a topic public double[] multiPros; // Path to the directory containing the corpus public String folderPath; // Path to the topic modeling corpus public String corpusPath; public String vectorFilePath; public double[][] wordVectors; // Vector representations for words public double[][] topicVectors;// Vector representations for topics public int vectorSize; // Number of vector dimensions public double[][] dotProductValues; public double[][] expDotProductValues; public double[] sumExpValues; // Partition function values public final double l2Regularizer = 0.01; // L2 regularizer value for learning topic vectors public final double tolerance = 0.05; // Tolerance value for LBFGS convergence public String expName = "LFDMM"; public String orgExpName = "LFDMM"; public String tAssignsFilePath = ""; public int savestep = 0; public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics, double inAlpha, double inBeta, double inLambda, int inNumInitIterations, int inNumIterations, int inTopWords) throws Exception { this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda, inNumInitIterations, inNumIterations, inTopWords, "LFDMM"); } public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics, double inAlpha, double inBeta, double inLambda, int inNumInitIterations, int inNumIterations, int inTopWords, String inExpName) throws Exception { this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda, inNumInitIterations, inNumIterations, inTopWords, inExpName, "", 0); } public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics, double inAlpha, double 
inBeta, double inLambda, int inNumInitIterations, int inNumIterations, int inTopWords, String inExpName, String pathToTAfile) throws Exception { this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda, inNumInitIterations, inNumIterations, inTopWords, inExpName, pathToTAfile, 0); } public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics, double inAlpha, double inBeta, double inLambda, int inNumInitIterations, int inNumIterations, int inTopWords, String inExpName, int inSaveStep) throws Exception { this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda, inNumInitIterations, inNumIterations, inTopWords, inExpName, "", inSaveStep); } public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics, double inAlpha, double inBeta, double inLambda, int inNumInitIterations, int inNumIterations, int inTopWords, String inExpName, String pathToTAfile, int inSaveStep) throws Exception { alpha = inAlpha; beta = inBeta; lambda = inLambda; numTopics = inNumTopics; numIterations = inNumIterations; numInitIterations = inNumInitIterations; topWords = inTopWords; savestep = inSaveStep; expName = inExpName; orgExpName = expName; vectorFilePath = pathToWordVectorsFile; corpusPath = pathToCorpus; folderPath = pathToCorpus.substring(0, Math.max(pathToCorpus.lastIndexOf("/"), pathToCorpus.lastIndexOf("\\")) + 1); System.out.println("Reading topic modeling corpus: " + pathToCorpus); word2IdVocabulary = new HashMap<String, Integer>(); id2WordVocabulary = new HashMap<Integer, String>(); corpus = new ArrayList<List<Integer>>(); numDocuments = 0; numWordsInCorpus = 0; BufferedReader br = null; try { int indexWord = -1; br = new BufferedReader(new FileReader(pathToCorpus)); for (String doc; (doc = br.readLine()) != null;) { if (doc.trim().length() == 0) continue; String[] words = doc.trim().split("\\s+"); List<Integer> document = new ArrayList<Integer>(); for (String word : words) { if (word2IdVocabulary.containsKey(word)) { document.add(word2IdVocabulary.get(word)); }
else { indexWord += 1; word2IdVocabulary.put(word, indexWord); id2WordVocabulary.put(indexWord, word); document.add(indexWord); } } numDocuments++; numWordsInCorpus += document.size(); corpus.add(document); } } catch (Exception e) { e.printStackTrace(); } vocabularySize = word2IdVocabulary.size(); docTopicCount = new int[numTopics]; topicWordCountDMM = new int[numTopics][vocabularySize]; sumTopicWordCountDMM = new int[numTopics]; topicWordCountLF = new int[numTopics][vocabularySize]; sumTopicWordCountLF = new int[numTopics]; multiPros = new double[numTopics]; for (int i = 0; i < numTopics; i++) { multiPros[i] = 1.0 / numTopics; } // alphaSum = numTopics * alpha; betaSum = vocabularySize * beta; readWordVectorsFile(vectorFilePath); topicVectors = new double[numTopics][vectorSize]; dotProductValues = new double[numTopics][vocabularySize]; expDotProductValues = new double[numTopics][vocabularySize]; sumExpValues = new double[numTopics]; System.out .println("Corpus size: " + numDocuments + " docs, " + numWordsInCorpus + " words"); System.out.println("Vocabulary size: " + vocabularySize); System.out.println("Number of topics: " + numTopics); System.out.println("alpha: " + alpha); System.out.println("beta: " + beta); System.out.println("lambda: " + lambda); System.out.println("Number of initial sampling iterations: " + numInitIterations); System.out.println("Number of EM-style sampling iterations for the LF-DMM model: " + numIterations); System.out.println("Number of top topical words: " + topWords); tAssignsFilePath = pathToTAfile; if (tAssignsFilePath.length() > 0) initialize(tAssignsFilePath); else initialize(); } public void readWordVectorsFile(String pathToWordVectorsFile) throws Exception { System.out.println("Reading word vectors from word-vectors file " + pathToWordVectorsFile + "..."); BufferedReader br = null; try { br = new BufferedReader(new FileReader(pathToWordVectorsFile)); String[] elements = br.readLine().trim().split("\\s+"); vectorSize = elements.length
- 1; wordVectors = new double[vocabularySize][vectorSize]; String word = elements[0]; if (word2IdVocabulary.containsKey(word)) { for (int j = 0; j < vectorSize; j++) { wordVectors[word2IdVocabulary.get(word)][j] = new Double(elements[j + 1]); } } for (String line; (line = br.readLine()) != null;) { elements = line.trim().split("\\s+"); word = elements[0]; if (word2IdVocabulary.containsKey(word)) { for (int j = 0; j < vectorSize; j++) { wordVectors[word2IdVocabulary.get(word)][j] = new Double(elements[j + 1]); } } } } catch (Exception e) { e.printStackTrace(); } for (int i = 0; i < vocabularySize; i++) { if (MatrixOps.absNorm(wordVectors[i]) == 0.0) { System.out.println("The word \"" + id2WordVocabulary.get(i) + "\" doesn't have a corresponding vector!!!"); throw new Exception(); } } } public void initialize() throws IOException { System.out.println("Randomly initializing topic assignments ..."); topicAssignments = new ArrayList<List<Integer>>(); for (int docId = 0; docId < numDocuments; docId++) { List<Integer> topics = new ArrayList<Integer>(); int topic = FuncUtils.nextDiscrete(multiPros); docTopicCount[topic] += 1; int docSize = corpus.get(docId).size(); for (int j = 0; j < docSize; j++) { int wordId = corpus.get(docId).get(j); boolean component = new Randoms().nextBoolean(); int subtopic = topic; if (!component) { // Generated from the latent feature component topicWordCountLF[topic][wordId] += 1; sumTopicWordCountLF[topic] += 1; } else {// Generated from the Dirichlet multinomial component topicWordCountDMM[topic][wordId] += 1; sumTopicWordCountDMM[topic] += 1; subtopic = subtopic + numTopics; } topics.add(subtopic); } topicAssignments.add(topics); } } public void initialize(String pathToTopicAssignmentFile) throws Exception { System.out.println("Reading topic-assignment file: " + pathToTopicAssignmentFile); topicAssignments = new ArrayList<List<Integer>>(); BufferedReader br = null; try { br = new BufferedReader(new FileReader(pathToTopicAssignmentFile)); int docId = 0; int numWords = 0; for (String line;
(line = br.readLine()) != null;) { String[] strTopics = line.trim().split("\\s+"); List<Integer> topics = new ArrayList<Integer>(); int topic = new Integer(strTopics[0]) % numTopics; docTopicCount[topic] += 1; for (int j = 0; j < strTopics.length; j++) { int wordId = corpus.get(docId).get(j); int subtopic = new Integer(strTopics[j]); if (subtopic == topic) { topicWordCountLF[topic][wordId] += 1; sumTopicWordCountLF[topic] += 1; } else { topicWordCountDMM[topic][wordId] += 1; sumTopicWordCountDMM[topic] += 1; } topics.add(subtopic); numWords++; } topicAssignments.add(topics); docId++; } if ((docId != numDocuments) || (numWords != numWordsInCorpus)) { System.out .println("The topic modeling corpus and topic assignment file are not consistent!!!"); throw new Exception(); } } catch (Exception e) { e.printStackTrace(); /* rethrow so that an inconsistent topic-assignment file stops initialization */ throw e; } } public void inference() throws IOException { System.out.println("Running Gibbs sampling inference: "); for (int iter = 1; iter <= numInitIterations; iter++) { System.out.println("\tInitial sampling iteration: " + (iter)); sampleSingleInitialIteration(); } for (int iter = 1; iter <= numIterations; iter++) { System.out.println("\tLFDMM sampling iteration: " + (iter)); optimizeTopicVectors(); sampleSingleIteration(); if ((savestep > 0) && (iter % savestep == 0) && (iter < numIterations)) { System.out.println("\t\tSaving the output from the " + iter + "^{th} sample"); expName = orgExpName + "-" + iter; write(); } } expName = orgExpName; writeParameters(); System.out.println("Writing output from the last sample ..."); write(); System.out.println("Sampling completed!"); } public void optimizeTopicVectors() { System.out.println("\t\tEstimating topic vectors ..."); sumExpValues = new double[numTopics]; dotProductValues = new double[numTopics][vocabularySize]; expDotProductValues = new double[numTopics][vocabularySize]; Parallel.loop(numTopics, new Parallel.LoopInt() { @Override public void compute(int topic) { int rate = 1; boolean check = true; while (check) { double l2Value
= l2Regularizer * rate; try { TopicVectorOptimizer optimizer = new TopicVectorOptimizer( topicVectors[topic], topicWordCountLF[topic], wordVectors, l2Value); Optimizer gd = new LBFGS(optimizer, tolerance); gd.optimize(600); optimizer.getParameters(topicVectors[topic]); sumExpValues[topic] = optimizer.computePartitionFunction( dotProductValues[topic], expDotProductValues[topic]); check = false; if (sumExpValues[topic] == 0 || Double.isInfinite(sumExpValues[topic])) { double max = -1000000000.0; for (int index = 0; index < vocabularySize; index++) { if (dotProductValues[topic][index] > max) max = dotProductValues[topic][index]; } for (int index = 0; index < vocabularySize; index++) { expDotProductValues[topic][index] = Math .exp(dotProductValues[topic][index] - max); sumExpValues[topic] += expDotProductValues[topic][index]; } } } catch (InvalidOptimizableException e) { e.printStackTrace(); check = true; } rate = rate * 10; } } }); } public void sampleSingleIteration() { for (int dIndex = 0; dIndex < numDocuments; dIndex++) { List document = corpus.get(dIndex); int docSize = document.size(); int topic = topicAssignments.get(dIndex).get(0) % numTopics; docTopicCount[topic] = docTopicCount[topic] - 1; for (int wIndex = 0; wIndex < docSize; wIndex++) { int word = document.get(wIndex);// wordId int subtopic = topicAssignments.get(dIndex).get(wIndex); if (subtopic == topic) { topicWordCountLF[topic][word] -= 1; sumTopicWordCountLF[topic] -= 1; } else { topicWordCountDMM[topic][word] -= 1; sumTopicWordCountDMM[topic] -= 1; } } // Sample a topic for (int tIndex = 0; tIndex < numTopics; tIndex++) { multiPros[tIndex] = (docTopicCount[tIndex] + alpha); for (int wIndex = 0; wIndex < docSize; wIndex++) { int word = document.get(wIndex); multiPros[tIndex] *= (lambda * expDotProductValues[tIndex][word] / sumExpValues[tIndex] + (1 - lambda) * (topicWordCountDMM[tIndex][word] + beta) / (sumTopicWordCountDMM[tIndex] + betaSum)); } } topic = FuncUtils.nextDiscrete(multiPros); 
			docTopicCount[topic] += 1;
			for (int wIndex = 0; wIndex < docSize; wIndex++) {
				int word = document.get(wIndex);
				int subtopic = topic;
				if (lambda * expDotProductValues[topic][word] / sumExpValues[topic]
					> (1 - lambda) * (topicWordCountDMM[topic][word] + beta)
						/ (sumTopicWordCountDMM[topic] + betaSum)) {
					topicWordCountLF[topic][word] += 1;
					sumTopicWordCountLF[topic] += 1;
				} else {
					topicWordCountDMM[topic][word] += 1;
					sumTopicWordCountDMM[topic] += 1;
					subtopic += numTopics;
				}
				// Update topic assignments
				topicAssignments.get(dIndex).set(wIndex, subtopic);
			}
		}
	}

	public void sampleSingleInitialIteration() {
		for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
			List<Integer> document = corpus.get(dIndex);
			int docSize = document.size();
			int topic = topicAssignments.get(dIndex).get(0) % numTopics;

			docTopicCount[topic] = docTopicCount[topic] - 1;
			for (int wIndex = 0; wIndex < docSize; wIndex++) {
				int word = document.get(wIndex);
				int subtopic = topicAssignments.get(dIndex).get(wIndex);
				if (topic == subtopic) {
					topicWordCountLF[topic][word] -= 1;
					sumTopicWordCountLF[topic] -= 1;
				} else {
					topicWordCountDMM[topic][word] -= 1;
					sumTopicWordCountDMM[topic] -= 1;
				}
			}

			// Sample a topic
			for (int tIndex = 0; tIndex < numTopics; tIndex++) {
				multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
				for (int wIndex = 0; wIndex < docSize; wIndex++) {
					int word = document.get(wIndex);
					multiPros[tIndex] *= (lambda * (topicWordCountLF[tIndex][word] + beta)
						/ (sumTopicWordCountLF[tIndex] + betaSum)
						+ (1 - lambda) * (topicWordCountDMM[tIndex][word] + beta)
							/ (sumTopicWordCountDMM[tIndex] + betaSum));
				}
			}
			topic = FuncUtils.nextDiscrete(multiPros);

			docTopicCount[topic] += 1;
			for (int wIndex = 0; wIndex < docSize; wIndex++) {
				int word = document.get(wIndex); // wordID
				int subtopic = topic;
				if (lambda * (topicWordCountLF[topic][word] + beta)
					/ (sumTopicWordCountLF[topic] + betaSum)
					> (1 - lambda) * (topicWordCountDMM[topic][word] + beta)
						/ (sumTopicWordCountDMM[topic] + betaSum)) {
					topicWordCountLF[topic][word] += 1;
					sumTopicWordCountLF[topic] += 1;
				} else {
					topicWordCountDMM[topic][word] += 1;
					sumTopicWordCountDMM[topic] += 1;
					subtopic += numTopics;
				}
				// Update topic assignments
				topicAssignments.get(dIndex).set(wIndex, subtopic);
			}
		}
	}

	public void writeParameters() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".paras"));
		writer.write("-model" + "\t" + "LFDMM");
		writer.write("\n-corpus" + "\t" + corpusPath);
		writer.write("\n-vectors" + "\t" + vectorFilePath);
		writer.write("\n-ntopics" + "\t" + numTopics);
		writer.write("\n-alpha" + "\t" + alpha);
		writer.write("\n-beta" + "\t" + beta);
		writer.write("\n-lambda" + "\t" + lambda);
		writer.write("\n-initers" + "\t" + numInitIterations);
		writer.write("\n-niters" + "\t" + numIterations);
		writer.write("\n-twords" + "\t" + topWords);
		writer.write("\n-name" + "\t" + expName);
		if (tAssignsFilePath.length() > 0)
			writer.write("\n-initFile" + "\t" + tAssignsFilePath);
		if (savestep > 0)
			writer.write("\n-sstep" + "\t" + savestep);

		writer.close();
	}

	public void writeDictionary() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".vocabulary"));
		for (String word : word2IdVocabulary.keySet()) {
			writer.write(word + " " + word2IdVocabulary.get(word) + "\n");
		}
		writer.close();
	}

	public void writeIDbasedCorpus() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".IDcorpus"));
		for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
			int docSize = corpus.get(dIndex).size();
			for (int wIndex = 0; wIndex < docSize; wIndex++) {
				writer.write(corpus.get(dIndex).get(wIndex) + " ");
			}
			writer.write("\n");
		}
		writer.close();
	}

	public void writeTopicAssignments() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
			+ ".topicAssignments"));
		for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
			int docSize = corpus.get(dIndex).size();
			for (int wIndex = 0; wIndex < docSize; wIndex++) {
				writer.write(topicAssignments.get(dIndex).get(wIndex) + " ");
			}
			writer.write("\n");
		}
		writer.close();
	}

	public void writeTopicVectors() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
			+ ".topicVectors"));
		for (int i = 0; i < numTopics; i++) {
			for (int j = 0; j < vectorSize; j++)
				writer.write(topicVectors[i][j] + " ");
			writer.write("\n");
		}
		writer.close();
	}

	public void writeTopTopicalWords() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".topWords"));

		for (int tIndex = 0; tIndex < numTopics; tIndex++) {
			writer.write("Topic" + new Integer(tIndex) + ":");

			Map<Integer, Double> topicWordProbs = new TreeMap<Integer, Double>();
			for (int wIndex = 0; wIndex < vocabularySize; wIndex++) {
				double pro = lambda * expDotProductValues[tIndex][wIndex] / sumExpValues[tIndex]
					+ (1 - lambda) * (topicWordCountDMM[tIndex][wIndex] + beta)
						/ (sumTopicWordCountDMM[tIndex] + betaSum);
				topicWordProbs.put(wIndex, pro);
			}
			topicWordProbs = FuncUtils.sortByValueDescending(topicWordProbs);

			Set<Integer> mostLikelyWords = topicWordProbs.keySet();
			int count = 0;
			for (Integer index : mostLikelyWords) {
				if (count < topWords) {
					writer.write(" " + id2WordVocabulary.get(index));
					count += 1;
				} else {
					writer.write("\n\n");
					break;
				}
			}
		}
		writer.close();
	}

	public void writeTopicWordPros() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".phi"));
		for (int t = 0; t < numTopics; t++) {
			for (int w = 0; w < vocabularySize; w++) {
				double pro = lambda * expDotProductValues[t][w] / sumExpValues[t]
					+ (1 - lambda) * (topicWordCountDMM[t][w] + beta)
						/ (sumTopicWordCountDMM[t] + betaSum);
				writer.write(pro + " ");
			}
			writer.write("\n");
		}
		writer.close();
	}

	public void writeDocTopicPros() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".theta"));

		for (int i = 0; i < numDocuments; i++) {
			int docSize = corpus.get(i).size();
			double sum = 0.0;
			for (int tIndex = 0; tIndex < numTopics; tIndex++) {
				multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
				for (int wIndex = 0; wIndex < docSize; wIndex++) {
					int word = corpus.get(i).get(wIndex);
					multiPros[tIndex] *= (lambda * expDotProductValues[tIndex][word] / sumExpValues[tIndex]
						+ (1 - lambda) * (topicWordCountDMM[tIndex][word] + beta)
							/ (sumTopicWordCountDMM[tIndex] + betaSum));
				}
				sum += multiPros[tIndex];
			}
			for (int tIndex = 0; tIndex < numTopics; tIndex++) {
				writer.write((multiPros[tIndex] / sum) + " ");
			}
			writer.write("\n");
		}
		writer.close();
	}

	public void write() throws IOException {
		writeTopTopicalWords();
		writeDocTopicPros();
		writeTopicAssignments();
		writeTopicWordPros();
	}

	public static void main(String args[]) throws Exception {
		LFDMM lfdmm = new LFDMM("test/corpus.txt", "test/wordVectors.txt", 4, 0.1, 0.01, 0.6, 2000,
			200, 20, "testLFDMM");
		lfdmm.writeParameters();
		lfdmm.inference();
	}
}

================================================
FILE: src/models/LFDMM_Inf.java
================================================
package models;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import utility.FuncUtils;
import utility.LBFGS;
import utility.Parallel;
import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.Randoms;

/**
 * Implementation of the LF-DMM latent feature topic model, using collapsed
 * Gibbs sampling, as described in:
 * 
 * Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. 2015.
 * Improving Topic Models with Latent Feature Word Representations. Transactions
 * of the Association for Computational Linguistics, vol. 3, pp. 299-313.
 * 
 * Inference of topic distributions on an unseen corpus
 * 
 * @author Dat Quoc Nguyen
 */

public class LFDMM_Inf {
	public double alpha; // Hyper-parameter alpha
	public double beta; // Hyper-parameter beta
	// public double alphaSum; // alpha * numTopics
	public double betaSum; // beta * vocabularySize

	public int numTopics; // Number of topics
	public int topWords; // Number of most probable words for each topic

	public double lambda; // Mixture weight value
	public int numInitIterations;
	public int numIterations; // Number of EM-style sampling iterations

	public List<List<Integer>> corpus; // Word ID-based corpus
	public List<List<Integer>> topicAssignments; // Topic assignments for words in the corpus
	public int numDocuments; // Number of documents in the corpus
	public int numWordsInCorpus; // Number of words in the corpus

	public HashMap<String, Integer> word2IdVocabulary; // Vocabulary to get ID given a word
	public HashMap<Integer, String> id2WordVocabulary; // Vocabulary to get word given an ID
	public int vocabularySize; // The number of word types in the corpus

	// Number of documents assigned to a topic
	public int[] docTopicCount;
	// numTopics * vocabularySize matrix
	// Given a topic: number of times a word type is generated from the topic by
	// the Dirichlet multinomial component
	public int[][] topicWordCountDMM;
	// Total number of words generated from each topic by the Dirichlet
	// multinomial component
	public int[] sumTopicWordCountDMM;
	// numTopics * vocabularySize matrix
	// Given a topic: number of times a word type is generated from the topic by
	// the latent feature component
	public int[][] topicWordCountLF;
	// Total number of words generated from each topic by the latent feature
	// component
	public int[] sumTopicWordCountLF;

	// Double array used to sample a topic
	public double[] multiPros;
	// Path to the directory containing the corpus
	public String folderPath;
	// Path to the topic modeling corpus
	public String corpusPath;
	public String vectorFilePath;

	public double[][] wordVectors; // Vector representations for words
	public double[][] topicVectors; // Vector representations for topics
	public int vectorSize; // Number of vector dimensions
	public double[][] dotProductValues;
	public double[][] expDotProductValues;
	public double[] sumExpValues; // Partition function values

	public final double l2Regularizer = 0.01; // L2 regularizer value for learning topic vectors
	public final double tolerance = 0.05; // Tolerance value for LBFGS convergence

	public String expName = "LFDMMinf";
	public String orgExpName = "LFDMMinf";
	public int savestep = 0;

	public LFDMM_Inf(String pathToTrainingParasFile, String pathToUnseenCorpus,
		int inNumInitIterations, int inNumIterations, int inTopWords, String inExpName,
		int inSaveStep) throws Exception {
		HashMap<String, String> paras = parseTrainingParasFile(pathToTrainingParasFile);
		if (!paras.get("-model").equals("LFDMM")) {
			throw new Exception("Wrong pre-trained model!!!");
		}
		alpha = new Double(paras.get("-alpha"));
		beta = new Double(paras.get("-beta"));
		lambda = new Double(paras.get("-lambda"));
		numTopics = new Integer(paras.get("-ntopics"));

		numIterations = inNumIterations;
		numInitIterations = inNumInitIterations;
		topWords = inTopWords;
		savestep = inSaveStep;
		expName = inExpName;
		orgExpName = expName;
		vectorFilePath = paras.get("-vectors");

		String trainingCorpus = paras.get("-corpus");
		String trainingCorpusfolder = trainingCorpus.substring(0,
			Math.max(trainingCorpus.lastIndexOf("/"), trainingCorpus.lastIndexOf("\\")) + 1);
		String topicAssignment4TrainFile = trainingCorpusfolder + paras.get("-name")
			+ ".topicAssignments";

		word2IdVocabulary = new HashMap<String, Integer>();
		id2WordVocabulary = new HashMap<Integer, String>();
		initializeWordCount(trainingCorpus, topicAssignment4TrainFile);

		corpusPath = pathToUnseenCorpus;
		folderPath = pathToUnseenCorpus.substring(0,
			Math.max(pathToUnseenCorpus.lastIndexOf("/"), pathToUnseenCorpus.lastIndexOf("\\")) + 1);
		System.out.println("Reading unseen corpus: " + pathToUnseenCorpus);
		corpus = new ArrayList<List<Integer>>();
		numDocuments = 0;
		numWordsInCorpus = 0;
		BufferedReader br = null;
		try {
			br = new BufferedReader(new FileReader(pathToUnseenCorpus));
			for (String doc; (doc = br.readLine()) != null;) {
				if (doc.trim().length() == 0)
					continue;

				String[] words = doc.trim().split("\\s+");
				List<Integer> document = new ArrayList<Integer>();

				for (String word : words) {
					if (word2IdVocabulary.containsKey(word)) {
						document.add(word2IdVocabulary.get(word));
					} else {
						// Skip words that do not occur in the training vocabulary
					}
				}

				numDocuments++;
				numWordsInCorpus += document.size();
				corpus.add(document);
			}
		} catch (Exception e) {
			e.printStackTrace();
		}

		docTopicCount = new int[numTopics];
		multiPros = new double[numTopics];
		for (int i = 0; i < numTopics; i++) {
			multiPros[i] = 1.0 / numTopics;
		}

		// alphaSum = numTopics * alpha;
		betaSum = vocabularySize * beta;

		readWordVectorsFile(vectorFilePath);
		topicVectors = new double[numTopics][vectorSize];
		dotProductValues = new double[numTopics][vocabularySize];
		expDotProductValues = new double[numTopics][vocabularySize];
		sumExpValues = new double[numTopics];

		System.out.println("Corpus size: " + numDocuments + " docs, " + numWordsInCorpus + " words");
		System.out.println("Vocabulary size: " + vocabularySize);
		System.out.println("Number of topics: " + numTopics);
		System.out.println("alpha: " + alpha);
		System.out.println("beta: " + beta);
		System.out.println("lambda: " + lambda);
		System.out.println("Number of initial sampling iterations: " + numInitIterations);
		System.out.println("Number of EM-style sampling iterations for the LF-DMM model: "
			+ numIterations);
		System.out.println("Number of top topical words: " + topWords);

		initialize();
	}

	private HashMap<String, String> parseTrainingParasFile(String pathToTrainingParasFile)
		throws Exception {
		HashMap<String, String> paras = new HashMap<String, String>();
		BufferedReader br = null;
		try {
			br = new BufferedReader(new FileReader(pathToTrainingParasFile));
			for (String line; (line = br.readLine()) != null;) {
				if (line.trim().length() == 0)
					continue;

				String[] paraOptions = line.trim().split("\\s+");
				paras.put(paraOptions[0], paraOptions[1]);
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
		return paras;
	}

	private void initializeWordCount(String pathToTrainingCorpus, String pathToTopicAssignmentFile) {
		System.out.println("Loading pre-trained model...");
		List<List<Integer>> trainCorpus = new ArrayList<List<Integer>>();
		BufferedReader br = null;
		try {
			int indexWord = -1;
			br = new BufferedReader(new FileReader(pathToTrainingCorpus));
			for (String doc; (doc = br.readLine()) != null;) {
				if (doc.trim().length() == 0)
					continue;

				String[] words = doc.trim().split("\\s+");
				List<Integer> document = new ArrayList<Integer>();

				for (String word : words) {
					if (word2IdVocabulary.containsKey(word)) {
						document.add(word2IdVocabulary.get(word));
					} else {
						indexWord += 1;
						word2IdVocabulary.put(word, indexWord);
						id2WordVocabulary.put(indexWord, word);
						document.add(indexWord);
					}
				}
				trainCorpus.add(document);
			}
		} catch (Exception e) {
			e.printStackTrace();
		}

		vocabularySize = word2IdVocabulary.size();
		topicWordCountDMM = new int[numTopics][vocabularySize];
		sumTopicWordCountDMM = new int[numTopics];
		topicWordCountLF = new int[numTopics][vocabularySize];
		sumTopicWordCountLF = new int[numTopics];

		try {
			br = new BufferedReader(new FileReader(pathToTopicAssignmentFile));
			int docId = 0;
			for (String line; (line = br.readLine()) != null;) {
				String[] strTopics = line.trim().split("\\s+");
				int topic = new Integer(strTopics[0]) % numTopics;
				for (int j = 0; j < strTopics.length; j++) {
					int wordId = trainCorpus.get(docId).get(j);
					int subtopic = new Integer(strTopics[j]);
					if (subtopic == topic) {
						topicWordCountLF[topic][wordId] += 1;
						sumTopicWordCountLF[topic] += 1;
					} else {
						topicWordCountDMM[topic][wordId] += 1;
						sumTopicWordCountDMM[topic] += 1;
					}
				}
				docId++;
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	public void readWordVectorsFile(String pathToWordVectorsFile) throws Exception {
		System.out.println("Reading word vectors from word-vectors file " + pathToWordVectorsFile
			+ "...");

		BufferedReader br = null;
		try {
			br = new BufferedReader(new FileReader(pathToWordVectorsFile));
			String[] elements = br.readLine().trim().split("\\s+");
			vectorSize = elements.length - 1;
			wordVectors = new double[vocabularySize][vectorSize];
			String word = elements[0];
			if (word2IdVocabulary.containsKey(word)) {
				for (int j = 0; j < vectorSize; j++) {
					wordVectors[word2IdVocabulary.get(word)][j] = new Double(elements[j + 1]);
				}
			}
			for (String line; (line = br.readLine()) != null;) {
				elements = line.trim().split("\\s+");
				word = elements[0];
				if (word2IdVocabulary.containsKey(word)) {
					for (int j = 0; j < vectorSize; j++) {
						wordVectors[word2IdVocabulary.get(word)][j] = new Double(elements[j + 1]);
					}
				}
			}
		} catch (Exception e) {
			e.printStackTrace();
		}

		for (int i = 0; i < vocabularySize; i++) {
			if (MatrixOps.absNorm(wordVectors[i]) == 0.0) {
				System.out.println("The word \"" + id2WordVocabulary.get(i)
					+ "\" doesn't have a corresponding vector!!!");
				throw new Exception();
			}
		}
	}

	public void initialize() throws IOException {
		System.out.println("Randomly initializing topic assignments ...");
		topicAssignments = new ArrayList<List<Integer>>();

		for (int docId = 0; docId < numDocuments; docId++) {
			List<Integer> topics = new ArrayList<Integer>();
			int topic = FuncUtils.nextDiscrete(multiPros);
			docTopicCount[topic] += 1;
			int docSize = corpus.get(docId).size();
			for (int j = 0; j < docSize; j++) {
				int wordId = corpus.get(docId).get(j);
				boolean component = new Randoms().nextBoolean();
				int subtopic = topic;
				if (!component) { // Generated from the latent feature component
					topicWordCountLF[topic][wordId] += 1;
					sumTopicWordCountLF[topic] += 1;
				} else { // Generated from the Dirichlet multinomial component
					topicWordCountDMM[topic][wordId] += 1;
					sumTopicWordCountDMM[topic] += 1;
					subtopic = subtopic + numTopics;
				}
				topics.add(subtopic);
			}
			topicAssignments.add(topics);
		}
	}

	public void inference() throws IOException {
		System.out.println("Running Gibbs sampling inference: ");

		for (int iter = 1; iter <= numInitIterations; iter++) {
			System.out.println("\tInitial sampling iteration: " + (iter));
			sampleSingleInitialIteration();
		}

		for (int iter = 1; iter <= numIterations; iter++) {
			System.out.println("\tLFDMM sampling iteration: " + (iter));
			optimizeTopicVectors();
			sampleSingleIteration();

			if ((savestep > 0) && (iter % savestep == 0) && (iter < numIterations)) {
				System.out.println("\t\tSaving the output from the " + iter + "^{th} sample");
				expName = orgExpName + "-" + iter;
				write();
			}
		}
		expName = orgExpName;

		writeParameters();
		System.out.println("Writing output from the last sample ...");
		write();

		System.out.println("Sampling completed!");
	}

	public void optimizeTopicVectors() {
		System.out.println("\t\tEstimating topic vectors ...");
		sumExpValues = new double[numTopics];
		dotProductValues = new double[numTopics][vocabularySize];
		expDotProductValues = new double[numTopics][vocabularySize];

		Parallel.loop(numTopics, new Parallel.LoopInt() {
			@Override
			public void compute(int topic) {
				int rate = 1;
				boolean check = true;
				while (check) {
					double l2Value = l2Regularizer * rate;
					try {
						TopicVectorOptimizer optimizer = new TopicVectorOptimizer(
							topicVectors[topic], topicWordCountLF[topic], wordVectors, l2Value);

						Optimizer gd = new LBFGS(optimizer, tolerance);
						gd.optimize(600);
						optimizer.getParameters(topicVectors[topic]);
						sumExpValues[topic] = optimizer.computePartitionFunction(
							dotProductValues[topic], expDotProductValues[topic]);
						check = false;

						if (sumExpValues[topic] == 0 || Double.isInfinite(sumExpValues[topic])) {
							// Fall back to a max-shifted softmax. Reset the partition function
							// first: adding finite terms to an infinite sum leaves it infinite.
							double max = -1000000000.0;
							for (int index = 0; index < vocabularySize; index++) {
								if (dotProductValues[topic][index] > max)
									max = dotProductValues[topic][index];
							}
							sumExpValues[topic] = 0.0;
							for (int index = 0; index < vocabularySize; index++) {
								expDotProductValues[topic][index] = Math
									.exp(dotProductValues[topic][index] - max);
								sumExpValues[topic] += expDotProductValues[topic][index];
							}
						}
					} catch (InvalidOptimizableException e) {
						e.printStackTrace();
						check = true;
					}
					rate = rate * 10;
				}
			}
		});
	}

	public void sampleSingleIteration() {
		for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
			List<Integer> document = corpus.get(dIndex);
			int docSize = document.size();
			int topic = topicAssignments.get(dIndex).get(0) % numTopics;

			docTopicCount[topic] = docTopicCount[topic] - 1;
			for (int wIndex = 0; wIndex < docSize; wIndex++) {
				int word = document.get(wIndex); // wordId
				int subtopic = topicAssignments.get(dIndex).get(wIndex);
				if (subtopic == topic) {
					topicWordCountLF[topic][word] -= 1;
					sumTopicWordCountLF[topic] -= 1;
				} else {
					topicWordCountDMM[topic][word] -= 1;
					sumTopicWordCountDMM[topic] -= 1;
				}
			}

			// Sample a topic
			for (int tIndex = 0; tIndex < numTopics; tIndex++) {
				multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
				for (int wIndex = 0; wIndex < docSize; wIndex++) {
					int word = document.get(wIndex);
					multiPros[tIndex] *= (lambda * expDotProductValues[tIndex][word] / sumExpValues[tIndex]
						+ (1 - lambda) * (topicWordCountDMM[tIndex][word] + beta)
							/ (sumTopicWordCountDMM[tIndex] + betaSum));
				}
			}
			topic = FuncUtils.nextDiscrete(multiPros);

			docTopicCount[topic] += 1;
			for (int wIndex = 0; wIndex < docSize; wIndex++) {
				int word = document.get(wIndex);
				int subtopic = topic;
				if (lambda * expDotProductValues[topic][word] / sumExpValues[topic]
					> (1 - lambda) * (topicWordCountDMM[topic][word] + beta)
						/ (sumTopicWordCountDMM[topic] + betaSum)) {
					topicWordCountLF[topic][word] += 1;
					sumTopicWordCountLF[topic] += 1;
				} else {
					topicWordCountDMM[topic][word] += 1;
					sumTopicWordCountDMM[topic] += 1;
					subtopic += numTopics;
				}
				// Update topic assignments
				topicAssignments.get(dIndex).set(wIndex, subtopic);
			}
		}
	}

	public void sampleSingleInitialIteration() {
		for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
			List<Integer> document = corpus.get(dIndex);
			int docSize = document.size();
			int topic = topicAssignments.get(dIndex).get(0) % numTopics;

			docTopicCount[topic] = docTopicCount[topic] - 1;
			for (int wIndex = 0; wIndex < docSize; wIndex++) {
				int word = document.get(wIndex);
				int subtopic = topicAssignments.get(dIndex).get(wIndex);
				if (topic == subtopic) {
					topicWordCountLF[topic][word] -= 1;
					sumTopicWordCountLF[topic] -= 1;
				} else {
					topicWordCountDMM[topic][word] -= 1;
					sumTopicWordCountDMM[topic] -= 1;
				}
			}

			// Sample a topic
			for (int tIndex = 0; tIndex < numTopics; tIndex++) {
				multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
				for (int wIndex = 0; wIndex < docSize; wIndex++) {
					int word = document.get(wIndex);
					multiPros[tIndex] *= (lambda * (topicWordCountLF[tIndex][word] + beta)
						/ (sumTopicWordCountLF[tIndex] + betaSum)
						+ (1 - lambda) * (topicWordCountDMM[tIndex][word] + beta)
							/ (sumTopicWordCountDMM[tIndex] + betaSum));
				}
			}
			topic = FuncUtils.nextDiscrete(multiPros);

			docTopicCount[topic] += 1;
			for (int wIndex = 0; wIndex < docSize; wIndex++) {
				int word = document.get(wIndex); // wordID
				int subtopic = topic;
				if (lambda * (topicWordCountLF[topic][word] + beta)
					/ (sumTopicWordCountLF[topic] + betaSum)
					> (1 - lambda) * (topicWordCountDMM[topic][word] + beta)
						/ (sumTopicWordCountDMM[topic] + betaSum)) {
					topicWordCountLF[topic][word] += 1;
					sumTopicWordCountLF[topic] += 1;
				} else {
					topicWordCountDMM[topic][word] += 1;
					sumTopicWordCountDMM[topic] += 1;
					subtopic += numTopics;
				}
				// Update topic assignments
				topicAssignments.get(dIndex).set(wIndex, subtopic);
			}
		}
	}

	public void writeParameters() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".paras"));
		writer.write("-model" + "\t" + "LFDMM");
		writer.write("\n-corpus" + "\t" + corpusPath);
		writer.write("\n-vectors" + "\t" + vectorFilePath);
		writer.write("\n-ntopics" + "\t" + numTopics);
		writer.write("\n-alpha" + "\t" + alpha);
		writer.write("\n-beta" + "\t" + beta);
		writer.write("\n-lambda" + "\t" + lambda);
		writer.write("\n-initers" + "\t" + numInitIterations);
		writer.write("\n-niters" + "\t" + numIterations);
		writer.write("\n-twords" + "\t" + topWords);
		writer.write("\n-name" + "\t" + expName);
		if (savestep > 0)
			writer.write("\n-sstep" + "\t" + savestep);

		writer.close();
	}

	public void writeDictionary() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".vocabulary"));
		for (String word : word2IdVocabulary.keySet()) {
			writer.write(word + " " + word2IdVocabulary.get(word) + "\n");
		}
		writer.close();
	}

	public void writeIDbasedCorpus() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".IDcorpus"));
		for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
			int docSize = corpus.get(dIndex).size();
			for (int wIndex = 0; wIndex < docSize; wIndex++) {
				writer.write(corpus.get(dIndex).get(wIndex) + " ");
			}
			writer.write("\n");
		}
		writer.close();
	}

	public void writeTopicAssignments() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
			+ ".topicAssignments"));
		for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
			int docSize = corpus.get(dIndex).size();
			for (int wIndex = 0; wIndex < docSize; wIndex++) {
				writer.write(topicAssignments.get(dIndex).get(wIndex) + " ");
			}
			writer.write("\n");
		}
		writer.close();
	}

	public void writeTopicVectors() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
			+ ".topicVectors"));
		for (int i = 0; i < numTopics; i++) {
			for (int j = 0; j < vectorSize; j++)
				writer.write(topicVectors[i][j] + " ");
			writer.write("\n");
		}
		writer.close();
	}

	public void writeTopTopicalWords() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".topWords"));

		for (int tIndex = 0; tIndex < numTopics; tIndex++) {
			writer.write("Topic" + new Integer(tIndex) + ":");

			Map<Integer, Double> topicWordProbs = new TreeMap<Integer, Double>();
			for (int wIndex = 0; wIndex < vocabularySize; wIndex++) {
				double pro = lambda * expDotProductValues[tIndex][wIndex] / sumExpValues[tIndex]
					+ (1 - lambda) * (topicWordCountDMM[tIndex][wIndex] + beta)
						/ (sumTopicWordCountDMM[tIndex] + betaSum);
				topicWordProbs.put(wIndex, pro);
			}
			topicWordProbs = FuncUtils.sortByValueDescending(topicWordProbs);

			Set<Integer> mostLikelyWords = topicWordProbs.keySet();
			int count = 0;
			for (Integer index : mostLikelyWords) {
				if (count < topWords) {
					writer.write(" " + id2WordVocabulary.get(index));
					count += 1;
				} else {
					writer.write("\n\n");
					break;
				}
			}
		}
		writer.close();
	}

	public void writeTopicWordPros() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".phi"));
		for (int t = 0; t < numTopics; t++) {
			for (int w = 0; w < vocabularySize; w++) {
				double pro = lambda * expDotProductValues[t][w] / sumExpValues[t]
					+ (1 - lambda) * (topicWordCountDMM[t][w] + beta)
						/ (sumTopicWordCountDMM[t] + betaSum);
				writer.write(pro + " ");
			}
			writer.write("\n");
		}
		writer.close();
	}

	public void writeDocTopicPros() throws IOException {
		BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".theta"));

		for (int i = 0; i < numDocuments; i++) {
			int docSize = corpus.get(i).size();
			double sum = 0.0;
			for (int tIndex = 0; tIndex < numTopics; tIndex++) {
				multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
				for (int wIndex = 0; wIndex < docSize; wIndex++) {
					int word = corpus.get(i).get(wIndex);
					multiPros[tIndex] *= (lambda * expDotProductValues[tIndex][word] / sumExpValues[tIndex]
						+ (1 - lambda) * (topicWordCountDMM[tIndex][word] + beta)
							/ (sumTopicWordCountDMM[tIndex] + betaSum));
				}
				sum += multiPros[tIndex];
			}
			for (int tIndex = 0; tIndex < numTopics; tIndex++) {
				writer.write((multiPros[tIndex] / sum) + " ");
			}
			writer.write("\n");
		}
		writer.close();
	}

	public void write() throws IOException {
		writeTopTopicalWords();
		writeDocTopicPros();
		writeTopicAssignments();
		writeTopicWordPros();
	}

	public static void main(String args[]) throws Exception {
		LFDMM_Inf lfdmm = new LFDMM_Inf("test/testLFDMM.paras", "test/corpus_test.txt", 2000,
			200, 20, "testLFDMMinf", 0);
		lfdmm.inference();
	}
}

================================================
FILE: src/models/LFLDA.java
================================================
package models;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import utility.FuncUtils;
import utility.LBFGS;
import utility.Parallel;
import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.MatrixOps;

/**
 * Implementation of the LF-LDA latent feature topic model, using collapsed Gibbs sampling, as
 * described in:
 * 
 * Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. 2015. Improving Topic Models with
 * Latent Feature Word Representations. Transactions of the Association for Computational
 * Linguistics, vol. 3, pp. 299-313.
 * 
 * @author Dat Quoc Nguyen
 */

public class LFLDA {
	public double alpha; // Hyper-parameter alpha
	public double beta; // Hyper-parameter beta
	public double alphaSum; // alpha * numTopics
	public double betaSum; // beta * vocabularySize

	public int numTopics; // Number of topics
	public int topWords; // Number of most probable words for each topic

	public double lambda; // Mixture weight value
	public int numInitIterations;
	public int numIterations; // Number of EM-style sampling iterations

	public List<List<Integer>> corpus; // Word ID-based corpus
	public List<List<Integer>> topicAssignments; // Topic assignments for words in the corpus
	public int numDocuments; // Number of documents in the corpus
	public int numWordsInCorpus; // Number of words in the corpus

	public HashMap<String, Integer> word2IdVocabulary; // Vocabulary to get ID given a word
	public HashMap<Integer, String> id2WordVocabulary; // Vocabulary to get word given an ID
	public int vocabularySize; // The number of word types in the corpus

	// numDocuments * numTopics matrix
	// Given a document: number of its words assigned to each topic
	public int[][] docTopicCount;
	// Number of words in every document
	public int[] sumDocTopicCount;
	// numTopics * vocabularySize matrix
	// Given a topic: number of times a word type is generated from the topic by
	// the Dirichlet multinomial component
	public int[][] topicWordCountLDA;
	// Total number of words generated from each topic by the Dirichlet
	// multinomial component
	public int[] sumTopicWordCountLDA;
	// numTopics * vocabularySize matrix
	// Given a topic: number of times a word type is generated from the topic by
	// the latent feature component
	public int[][] topicWordCountLF;
	// Total number of words generated from each topic by the latent feature
	// component
	public int[] sumTopicWordCountLF;

	// Double array used to sample a topic
	public double[] multiPros;
	// Path to the directory containing the corpus
	public String folderPath;
	// Path to the topic modeling corpus
	public String corpusPath;
	public String vectorFilePath;

	public double[][] wordVectors; // Vector representations for words
	public double[][] topicVectors; // Vector representations for topics
	public int vectorSize; // Number of vector dimensions
	public double[][] dotProductValues;
	public double[][] expDotProductValues;
	public double[] sumExpValues; // Partition function values

	public final double l2Regularizer = 0.01; // L2 regularizer value for learning topic vectors
	public final double tolerance = 0.05; // Tolerance value for LBFGS convergence

	public String expName = "LFLDA";
	public String orgExpName = "LFLDA";
	public String tAssignsFilePath = "";
	public int savestep = 0;

	public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
		double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
		int inNumIterations, int inTopWords) throws Exception {
		this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
			inNumInitIterations, inNumIterations, inTopWords, "LFLDA");
	}

	public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
		double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
		int inNumIterations, int inTopWords, String inExpName) throws Exception {
		this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
			inNumInitIterations, inNumIterations, inTopWords, inExpName, "", 0);
	}

	public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
		double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
		int inNumIterations, int inTopWords, String inExpName, String pathToTAfile)
		throws Exception {
		this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
			inNumInitIterations, inNumIterations, inTopWords, inExpName, pathToTAfile, 0);
	}

	public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
		double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
		int inNumIterations, int inTopWords, String inExpName, int inSaveStep)
		throws Exception {
		this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
			inNumInitIterations, inNumIterations, inTopWords, inExpName, "", inSaveStep);
	}

	public LFLDA(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
		double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
		int inNumIterations, int inTopWords, String inExpName, String pathToTAfile,
		int inSaveStep) throws Exception {
		alpha = inAlpha;
		beta = inBeta;
		lambda = inLambda;
		numTopics = inNumTopics;
		numIterations = inNumIterations;
		numInitIterations = inNumInitIterations;
		topWords = inTopWords;
		savestep = inSaveStep;
		expName = inExpName;
		orgExpName = expName;
		vectorFilePath = pathToWordVectorsFile;
		corpusPath = pathToCorpus;
		folderPath = pathToCorpus.substring(0,
			Math.max(pathToCorpus.lastIndexOf("/"), pathToCorpus.lastIndexOf("\\")) + 1);

		System.out.println("Reading topic modeling corpus: " + pathToCorpus);

		word2IdVocabulary = new HashMap<String, Integer>();
		id2WordVocabulary = new HashMap<Integer, String>();
		corpus = new ArrayList<List<Integer>>();
		numDocuments = 0;
		numWordsInCorpus = 0;

		BufferedReader br = null;
		try {
			int indexWord = -1;
			br = new BufferedReader(new FileReader(pathToCorpus));
			for (String doc; (doc = br.readLine()) != null;) {
				if (doc.trim().length() == 0)
					continue;
                String[] words = doc.trim().split("\\s+");
                List<Integer> document = new ArrayList<Integer>();

                for (String word : words) {
                    if (word2IdVocabulary.containsKey(word)) {
                        document.add(word2IdVocabulary.get(word));
                    } else {
                        indexWord += 1;
                        word2IdVocabulary.put(word, indexWord);
                        id2WordVocabulary.put(indexWord, word);
                        document.add(indexWord);
                    }
                }

                numDocuments++;
                numWordsInCorpus += document.size();
                corpus.add(document);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

        vocabularySize = word2IdVocabulary.size();
        docTopicCount = new int[numDocuments][numTopics];
        sumDocTopicCount = new int[numDocuments];
        topicWordCountLDA = new int[numTopics][vocabularySize];
        sumTopicWordCountLDA = new int[numTopics];
        topicWordCountLF = new int[numTopics][vocabularySize];
        sumTopicWordCountLF = new int[numTopics];

        multiPros = new double[numTopics * 2];
        for (int i = 0; i < numTopics * 2; i++) {
            multiPros[i] = 1.0 / numTopics;
        }

        alphaSum = numTopics * alpha;
        betaSum = vocabularySize * beta;

        readWordVectorsFile(vectorFilePath);
        topicVectors = new double[numTopics][vectorSize];
        dotProductValues = new double[numTopics][vocabularySize];
        expDotProductValues = new double[numTopics][vocabularySize];
        sumExpValues = new double[numTopics];

        System.out.println("Corpus size: " + numDocuments + " docs, " + numWordsInCorpus + " words");
        System.out.println("Vocabulary size: " + vocabularySize);
        System.out.println("Number of topics: " + numTopics);
        System.out.println("alpha: " + alpha);
        System.out.println("beta: " + beta);
        System.out.println("lambda: " + lambda);
        System.out.println("Number of initial sampling iterations: " + numInitIterations);
        System.out.println("Number of EM-style sampling iterations for the LF-LDA model: "
                + numIterations);
        System.out.println("Number of top topical words: " + topWords);

        tAssignsFilePath = pathToTAfile;
        if (tAssignsFilePath.length() > 0)
            initialize(tAssignsFilePath);
        else
            initialize();
    }

    public void readWordVectorsFile(String pathToWordVectorsFile) throws Exception {
        System.out.println("Reading word vectors from word-vectors file " + pathToWordVectorsFile
                + "...");

        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(pathToWordVectorsFile));
            String[] elements = br.readLine().trim().split("\\s+");
            vectorSize = elements.length - 1;
            wordVectors = new double[vocabularySize][vectorSize];
            String word = elements[0];
            if (word2IdVocabulary.containsKey(word)) {
                for (int j = 0; j < vectorSize; j++) {
                    wordVectors[word2IdVocabulary.get(word)][j] = Double.parseDouble(elements[j + 1]);
                }
            }
            for (String line; (line = br.readLine()) != null;) {
                elements = line.trim().split("\\s+");
                word = elements[0];
                if (word2IdVocabulary.containsKey(word)) {
                    for (int j = 0; j < vectorSize; j++) {
                        wordVectors[word2IdVocabulary.get(word)][j] = Double.parseDouble(elements[j + 1]);
                    }
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

        for (int i = 0; i < vocabularySize; i++) {
            if (MatrixOps.absNorm(wordVectors[i]) == 0.0) {
                System.out.println("The word \"" + id2WordVocabulary.get(i)
                        + "\" doesn't have a corresponding vector!!!");
                throw new Exception();
            }
        }
    }

    public void initialize() throws IOException {
        System.out.println("Randomly initializing topic assignments ...");
        topicAssignments = new ArrayList<List<Integer>>();

        for (int docId = 0; docId < numDocuments; docId++) {
            List<Integer> topics = new ArrayList<Integer>();
            int docSize = corpus.get(docId).size();
            for (int j = 0; j < docSize; j++) {
                int wordId = corpus.get(docId).get(j);

                int subtopic = FuncUtils.nextDiscrete(multiPros);
                int topic = subtopic % numTopics;
                if (topic == subtopic) { // Generated from the latent feature component
                    topicWordCountLF[topic][wordId] += 1;
                    sumTopicWordCountLF[topic] += 1;
                } else { // Generated from the Dirichlet multinomial component
                    topicWordCountLDA[topic][wordId] += 1;
                    sumTopicWordCountLDA[topic] += 1;
                }
                docTopicCount[docId][topic] += 1;
                sumDocTopicCount[docId] += 1;

                topics.add(subtopic);
            }
            topicAssignments.add(topics);
        }
    }

    public void initialize(String pathToTopicAssignmentFile) {
        System.out.println("Reading topic-assignment file: " +
                pathToTopicAssignmentFile);

        topicAssignments = new ArrayList<List<Integer>>();

        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(pathToTopicAssignmentFile));
            int docId = 0;
            int numWords = 0;
            for (String line; (line = br.readLine()) != null;) {
                String[] strTopics = line.trim().split("\\s+");
                List<Integer> topics = new ArrayList<Integer>();
                for (int j = 0; j < strTopics.length; j++) {
                    int wordId = corpus.get(docId).get(j);
                    int subtopic = Integer.parseInt(strTopics[j]);
                    int topic = subtopic % numTopics;
                    if (topic == subtopic) { // Generated from the latent feature component
                        topicWordCountLF[topic][wordId] += 1;
                        sumTopicWordCountLF[topic] += 1;
                    } else { // Generated from the Dirichlet multinomial component
                        topicWordCountLDA[topic][wordId] += 1;
                        sumTopicWordCountLDA[topic] += 1;
                    }
                    docTopicCount[docId][topic] += 1;
                    sumDocTopicCount[docId] += 1;

                    topics.add(subtopic);
                    numWords++;
                }
                topicAssignments.add(topics);
                docId++;
            }

            if ((docId != numDocuments) || (numWords != numWordsInCorpus)) {
                System.out.println(
                        "The topic modeling corpus and topic assignment file are not consistent!!!");
                throw new Exception();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public void inference() throws IOException {
        System.out.println("Running Gibbs sampling inference: ");

        for (int iter = 1; iter <= numInitIterations; iter++) {
            System.out.println("\tInitial sampling iteration: " + (iter));
            sampleSingleInitialIteration();
        }

        for (int iter = 1; iter <= numIterations; iter++) {
            System.out.println("\tLFLDA sampling iteration: " + (iter));

            optimizeTopicVectors();
            sampleSingleIteration();

            if ((savestep > 0) && (iter % savestep == 0) && (iter < numIterations)) {
                System.out.println("\t\tSaving the output from the " + iter + "^{th} sample");
                expName = orgExpName + "-" + iter;
                write();
            }
        }
        expName = orgExpName;

        writeParameters();
        System.out.println("Writing output from the last sample ...");
        write();

        System.out.println("Sampling completed!");
    }

    public void optimizeTopicVectors() {
        System.out.println("\t\tEstimating topic vectors ...");
        sumExpValues = new double[numTopics];
        dotProductValues = new double[numTopics][vocabularySize];
        expDotProductValues = new double[numTopics][vocabularySize];

        Parallel.loop(numTopics, new Parallel.LoopInt() {
            @Override
            public void compute(int topic) {
                int rate = 1;
                boolean check = true;
                while (check) {
                    double l2Value = l2Regularizer * rate;
                    try {
                        TopicVectorOptimizer optimizer = new TopicVectorOptimizer(
                                topicVectors[topic], topicWordCountLF[topic], wordVectors, l2Value);

                        Optimizer gd = new LBFGS(optimizer, tolerance);
                        gd.optimize(600);
                        optimizer.getParameters(topicVectors[topic]);
                        sumExpValues[topic] = optimizer.computePartitionFunction(
                                dotProductValues[topic], expDotProductValues[topic]);
                        check = false;

                        if (sumExpValues[topic] == 0 || Double.isInfinite(sumExpValues[topic])) {
                            // The partition function under- or overflowed: rescale by the
                            // maximum dot product and recompute the sum from zero (adding to
                            // an infinite sum would leave it infinite).
                            double max = -1000000000.0;
                            for (int index = 0; index < vocabularySize; index++) {
                                if (dotProductValues[topic][index] > max)
                                    max = dotProductValues[topic][index];
                            }
                            sumExpValues[topic] = 0.0;
                            for (int index = 0; index < vocabularySize; index++) {
                                expDotProductValues[topic][index] = Math.exp(dotProductValues[topic][index] - max);
                                sumExpValues[topic] += expDotProductValues[topic][index];
                            }
                        }
                    } catch (InvalidOptimizableException e) {
                        e.printStackTrace();
                        check = true;
                    }
                    rate = rate * 10;
                }
            }
        });
    }

    public void sampleSingleIteration() {
        for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
            int docSize = corpus.get(dIndex).size();
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                // Get current word
                int word = corpus.get(dIndex).get(wIndex); // wordID
                int subtopic = topicAssignments.get(dIndex).get(wIndex);
                int topic = subtopic % numTopics;

                docTopicCount[dIndex][topic] -= 1;
                if (subtopic == topic) {
                    topicWordCountLF[topic][word] -= 1;
                    sumTopicWordCountLF[topic] -= 1;
                } else {
                    topicWordCountLDA[topic][word] -= 1;
                    sumTopicWordCountLDA[topic] -= 1;
                }

                // Sample a pair of topic z and binary indicator variable s
                for (int tIndex = 0; tIndex < numTopics; tIndex++) {
                    multiPros[tIndex] =
                            (docTopicCount[dIndex][tIndex] + alpha) * lambda
                                    * expDotProductValues[tIndex][word] / sumExpValues[tIndex];
                    multiPros[tIndex + numTopics] = (docTopicCount[dIndex][tIndex] + alpha)
                            * (1 - lambda) * (topicWordCountLDA[tIndex][word] + beta)
                            / (sumTopicWordCountLDA[tIndex] + betaSum);
                }
                subtopic = FuncUtils.nextDiscrete(multiPros);
                topic = subtopic % numTopics;

                docTopicCount[dIndex][topic] += 1;
                if (subtopic == topic) {
                    topicWordCountLF[topic][word] += 1;
                    sumTopicWordCountLF[topic] += 1;
                } else {
                    topicWordCountLDA[topic][word] += 1;
                    sumTopicWordCountLDA[topic] += 1;
                }
                // Update topic assignments
                topicAssignments.get(dIndex).set(wIndex, subtopic);
            }
        }
    }

    public void sampleSingleInitialIteration() {
        for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
            int docSize = corpus.get(dIndex).size();
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                int word = corpus.get(dIndex).get(wIndex); // wordID
                int subtopic = topicAssignments.get(dIndex).get(wIndex);
                int topic = subtopic % numTopics;

                docTopicCount[dIndex][topic] -= 1;
                if (subtopic == topic) { // LF(w|t) + LDA(t|d)
                    topicWordCountLF[topic][word] -= 1;
                    sumTopicWordCountLF[topic] -= 1;
                } else { // LDA(w|t) + LDA(t|d)
                    topicWordCountLDA[topic][word] -= 1;
                    sumTopicWordCountLDA[topic] -= 1;
                }

                // Sample a pair of topic z and binary indicator variable s
                for (int tIndex = 0; tIndex < numTopics; tIndex++) {
                    multiPros[tIndex] = (docTopicCount[dIndex][tIndex] + alpha) * lambda
                            * (topicWordCountLF[tIndex][word] + beta)
                            / (sumTopicWordCountLF[tIndex] + betaSum);
                    multiPros[tIndex + numTopics] = (docTopicCount[dIndex][tIndex] + alpha)
                            * (1 - lambda) * (topicWordCountLDA[tIndex][word] + beta)
                            / (sumTopicWordCountLDA[tIndex] + betaSum);
                }
                subtopic = FuncUtils.nextDiscrete(multiPros);
                topic = subtopic % numTopics;

                docTopicCount[dIndex][topic] += 1;
                if (topic == subtopic) {
                    topicWordCountLF[topic][word] += 1;
                    sumTopicWordCountLF[topic] += 1;
                } else {
                    topicWordCountLDA[topic][word] += 1;
                    sumTopicWordCountLDA[topic] += 1;
                }
                // Update topic assignments
                topicAssignments.get(dIndex).set(wIndex, subtopic);
            }
        }
    }

    public void writeParameters() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".paras"));
        writer.write("-model" + "\t" + "LFLDA");
        writer.write("\n-corpus" + "\t" + corpusPath);
        writer.write("\n-vectors" + "\t" + vectorFilePath);
        writer.write("\n-ntopics" + "\t" + numTopics);
        writer.write("\n-alpha" + "\t" + alpha);
        writer.write("\n-beta" + "\t" + beta);
        writer.write("\n-lambda" + "\t" + lambda);
        writer.write("\n-initers" + "\t" + numInitIterations);
        writer.write("\n-niters" + "\t" + numIterations);
        writer.write("\n-twords" + "\t" + topWords);
        writer.write("\n-name" + "\t" + expName);
        if (tAssignsFilePath.length() > 0)
            writer.write("\n-initFile" + "\t" + tAssignsFilePath);
        if (savestep > 0)
            writer.write("\n-sstep" + "\t" + savestep);

        writer.close();
    }

    public void writeDictionary() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
                + ".vocabulary"));
        for (String word : word2IdVocabulary.keySet()) {
            writer.write(word + " " + word2IdVocabulary.get(word) + "\n");
        }
        writer.close();
    }

    public void writeIDbasedCorpus() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
                + ".IDcorpus"));
        for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
            int docSize = corpus.get(dIndex).size();
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                writer.write(corpus.get(dIndex).get(wIndex) + " ");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public void writeTopicAssignments() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
                + ".topicAssignments"));
        for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
            int docSize = corpus.get(dIndex).size();
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                writer.write(topicAssignments.get(dIndex).get(wIndex) + " ");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public void writeTopicVectors() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
                + ".topicVectors"));
        for (int i = 0; i < numTopics; i++) {
            for (int j = 0; j < vectorSize; j++)
                writer.write(topicVectors[i][j] + " ");
            writer.write("\n");
        }
        writer.close();
    }

    public void writeTopTopicalWords() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
                + ".topWords"));

        for (int tIndex = 0; tIndex < numTopics; tIndex++) {
            writer.write("Topic" + tIndex + ":");

            Map<Integer, Double> topicWordProbs = new TreeMap<Integer, Double>();
            for (int wIndex = 0; wIndex < vocabularySize; wIndex++) {
                double pro = lambda * expDotProductValues[tIndex][wIndex] / sumExpValues[tIndex]
                        + (1 - lambda) * (topicWordCountLDA[tIndex][wIndex] + beta)
                        / (sumTopicWordCountLDA[tIndex] + betaSum);
                topicWordProbs.put(wIndex, pro);
            }
            topicWordProbs = FuncUtils.sortByValueDescending(topicWordProbs);

            Set<Integer> mostLikelyWords = topicWordProbs.keySet();
            int count = 0;
            for (Integer index : mostLikelyWords) {
                if (count < topWords) {
                    writer.write(" " + id2WordVocabulary.get(index));
                    count += 1;
                } else {
                    writer.write("\n\n");
                    break;
                }
            }
        }
        writer.close();
    }

    public void writeTopicWordPros() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".phi"));
        for (int t = 0; t < numTopics; t++) {
            for (int w = 0; w < vocabularySize; w++) {
                double pro = lambda * expDotProductValues[t][w] / sumExpValues[t] + (1 - lambda)
                        * (topicWordCountLDA[t][w] + beta) / (sumTopicWordCountLDA[t] + betaSum);
                writer.write(pro + " ");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public void writeDocTopicPros() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".theta"));
        for (int i = 0; i < numDocuments; i++) {
            for (int j = 0; j < numTopics; j++) {
                double pro = (docTopicCount[i][j] + alpha) / (sumDocTopicCount[i] + alphaSum);
                writer.write(pro + " ");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public void write() throws IOException {
        writeTopTopicalWords();
        writeDocTopicPros();
        writeTopicAssignments();
        writeTopicWordPros();
    }

    public static void main(String[] args) throws Exception {
        LFLDA lflda = new LFLDA("test/corpus.txt", "test/wordVectors.txt", 4, 0.1, 0.01, 0.6, 2000,
                200, 20, "testLFLDA");
        lflda.writeParameters();
        lflda.inference();
    }
}

================================================
FILE: src/models/LFLDA_Inf.java
================================================
package models;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import utility.FuncUtils;
import utility.LBFGS;
import utility.Parallel;
import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.MatrixOps;

/**
 * Implementation of the LF-LDA latent feature topic model, using collapsed
 * Gibbs sampling, as described in:
 *
 * Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. 2015.
 * Improving Topic Models with Latent Feature Word Representations. Transactions
 * of the Association for Computational Linguistics, vol. 3, pp. 299-313.
 *
 * Inference of topic distribution on unseen corpus
 *
 * @author Dat Quoc Nguyen
 */
public class LFLDA_Inf {
    public double alpha; // Hyper-parameter alpha
    public double beta; // Hyper-parameter beta
    public double alphaSum; // alpha * numTopics
    public double betaSum; // beta * vocabularySize

    public int numTopics; // Number of topics
    public int topWords; // Number of most probable words for each topic

    public double lambda; // Mixture weight value
    public int numInitIterations;
    public int numIterations; // Number of EM-style sampling iterations

    public List<List<Integer>> corpus; // Word ID-based corpus
    public List<List<Integer>> topicAssignments; // Topic assignments for words in the corpus
    public int numDocuments; // Number of documents in the corpus
    public int numWordsInCorpus; // Number of words in the corpus

    public HashMap<String, Integer> word2IdVocabulary; // Vocabulary to get ID given a word
    public HashMap<Integer, String> id2WordVocabulary; // Vocabulary to get word given an ID
    public int vocabularySize; // The number of word types in the corpus

    // numDocuments * numTopics matrix
    // Given a document: number of its words assigned to each topic
    public int[][] docTopicCount;
    // Number of words in every document
    public int[] sumDocTopicCount;
    // numTopics * vocabularySize matrix
    // Given a topic: number of times a word type is generated from the topic by
    // the Dirichlet multinomial component
    public int[][] topicWordCountLDA;
    // Total number of words generated from each topic by the Dirichlet
    // multinomial component
    public int[] sumTopicWordCountLDA;
    // numTopics * vocabularySize matrix
    // Given a topic: number of times a word type is generated from the topic by
    // the latent feature component
    public int[][] topicWordCountLF;
    // Total number of words generated from each topic by the latent feature
    // component
    public int[] sumTopicWordCountLF;

    // Double array used to sample a topic
    public double[] multiPros;
    // Path to the directory containing the corpus
    public String folderPath;
    // Path to the topic modeling corpus
    public String corpusPath;
    public String vectorFilePath;

    public double[][] wordVectors; // Vector representations for words
    public double[][] topicVectors; // Vector representations for topics
    public int vectorSize; // Number of vector dimensions
    public double[][] dotProductValues;
    public double[][] expDotProductValues;
    public double[] sumExpValues; // Partition function values

    public final double l2Regularizer = 0.01; // L2 regularizer value for learning topic vectors
    public final double tolerance = 0.05; // Tolerance value for LBFGS convergence

    public String expName = "LFLDAinf";
    public String orgExpName = "LFLDAinf";
    public int savestep = 0;

    public LFLDA_Inf(String pathToTrainingParasFile, String pathToUnseenCorpus,
            int inNumInitIterations, int inNumIterations, int inTopWords, String inExpName,
            int inSaveStep) throws Exception {
        HashMap<String, String> paras = parseTrainingParasFile(pathToTrainingParasFile);
        if (!paras.get("-model").equals("LFLDA")) {
            throw new Exception("Wrong pre-trained model!!!");
        }
        alpha = Double.parseDouble(paras.get("-alpha"));
        beta = Double.parseDouble(paras.get("-beta"));
        lambda = Double.parseDouble(paras.get("-lambda"));
        numTopics = Integer.parseInt(paras.get("-ntopics"));

        numIterations = inNumIterations;
        numInitIterations = inNumInitIterations;
        topWords = inTopWords;
        savestep = inSaveStep;
        expName = inExpName;
        orgExpName = expName;
        vectorFilePath = paras.get("-vectors");

        String trainingCorpus = paras.get("-corpus");
        String trainingCorpusfolder = trainingCorpus.substring(0,
                Math.max(trainingCorpus.lastIndexOf("/"), trainingCorpus.lastIndexOf("\\")) + 1);
        String topicAssignment4TrainFile = trainingCorpusfolder + paras.get("-name")
                + ".topicAssignments";

        word2IdVocabulary = new HashMap<String, Integer>();
        id2WordVocabulary = new HashMap<Integer, String>();
        initializeWordCount(trainingCorpus, topicAssignment4TrainFile);

        corpusPath = pathToUnseenCorpus;
        folderPath = pathToUnseenCorpus.substring(0,
                Math.max(pathToUnseenCorpus.lastIndexOf("/"), pathToUnseenCorpus.lastIndexOf("\\"))
                        + 1);
        System.out.println("Reading unseen corpus: " + pathToUnseenCorpus);
        corpus = new ArrayList<List<Integer>>();
        numDocuments = 0;
        numWordsInCorpus = 0;

        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(pathToUnseenCorpus));
            for (String doc; (doc = br.readLine()) != null;) {
                if (doc.trim().length() == 0)
                    continue;

                String[] words = doc.trim().split("\\s+");
                List<Integer> document = new ArrayList<Integer>();

                for (String word : words) {
                    // Words that do not occur in the training corpus are skipped
                    if (word2IdVocabulary.containsKey(word)) {
                        document.add(word2IdVocabulary.get(word));
                    }
                }

                numDocuments++;
                numWordsInCorpus += document.size();
                corpus.add(document);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

        docTopicCount = new int[numDocuments][numTopics];
        sumDocTopicCount = new int[numDocuments];
        multiPros = new double[numTopics * 2];
        for (int i = 0; i < numTopics * 2; i++) {
            multiPros[i] = 1.0 / numTopics;
        }
        alphaSum = numTopics * alpha;
        betaSum = vocabularySize * beta;

        readWordVectorsFile(vectorFilePath);
        topicVectors = new double[numTopics][vectorSize];
        dotProductValues = new double[numTopics][vocabularySize];
        expDotProductValues = new double[numTopics][vocabularySize];
        sumExpValues = new double[numTopics];

        System.out.println("Corpus size: " + numDocuments + " docs, " + numWordsInCorpus + " words");
        System.out.println("Vocabulary size: " + vocabularySize);
        System.out.println("Number of topics: " + numTopics);
        System.out.println("alpha: " + alpha);
        System.out.println("beta: " + beta);
        System.out.println("lambda: " + lambda);
        System.out.println("Number of initial sampling iterations: " + numInitIterations);
        System.out.println("Number of EM-style sampling iterations for the LF-LDA model: "
                + numIterations);
        System.out.println("Number of top topical words: " + topWords);

        initialize();
    }

    private HashMap<String, String> parseTrainingParasFile(String pathToTrainingParasFile)
            throws Exception {
        HashMap<String, String> paras = new HashMap<String, String>();
        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(pathToTrainingParasFile));
            for (String line; (line = br.readLine()) != null;) {
                if (line.trim().length() == 0)
                    continue;

                String[] paraOptions = line.trim().split("\\s+");
                paras.put(paraOptions[0], paraOptions[1]);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return paras;
    }

    private void initializeWordCount(String pathToTrainingCorpus, String pathToTopicAssignmentFile) {
        System.out.println("Loading pre-trained model...");
        List<List<Integer>> trainCorpus = new ArrayList<List<Integer>>();
        BufferedReader br = null;
        try {
            int indexWord = -1;
            br = new BufferedReader(new FileReader(pathToTrainingCorpus));
            for (String doc; (doc = br.readLine()) != null;) {
                if (doc.trim().length() == 0)
                    continue;

                String[] words = doc.trim().split("\\s+");
                List<Integer> document = new ArrayList<Integer>();

                for (String word : words) {
                    if (word2IdVocabulary.containsKey(word)) {
                        document.add(word2IdVocabulary.get(word));
                    } else {
                        indexWord += 1;
                        word2IdVocabulary.put(word, indexWord);
                        id2WordVocabulary.put(indexWord, word);
                        document.add(indexWord);
                    }
                }
                trainCorpus.add(document);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

        vocabularySize = word2IdVocabulary.size();
        topicWordCountLDA = new int[numTopics][vocabularySize];
        sumTopicWordCountLDA = new int[numTopics];
        topicWordCountLF = new int[numTopics][vocabularySize];
        sumTopicWordCountLF = new int[numTopics];

        try {
            br = new BufferedReader(new FileReader(pathToTopicAssignmentFile));
            int docId = 0;
            for (String line; (line = br.readLine()) != null;) {
                String[] strTopics = line.trim().split("\\s+");
                for (int j = 0; j < strTopics.length; j++) {
                    int wordId = trainCorpus.get(docId).get(j);
                    int subtopic = Integer.parseInt(strTopics[j]);
                    int topic = subtopic % numTopics;
                    if (topic == subtopic) { // Generated from the latent feature component
                        topicWordCountLF[topic][wordId] += 1;
                        sumTopicWordCountLF[topic] += 1;
                    } else { // Generated from the Dirichlet multinomial component
                        topicWordCountLDA[topic][wordId] += 1;
                        sumTopicWordCountLDA[topic] += 1;
                    }
                }
                docId++;
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public void readWordVectorsFile(String pathToWordVectorsFile) throws Exception {
        System.out.println("Reading word vectors from word-vectors file " + pathToWordVectorsFile
                + "...");

        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(pathToWordVectorsFile));
            String[] elements = br.readLine().trim().split("\\s+");
            vectorSize = elements.length - 1;
            wordVectors = new double[vocabularySize][vectorSize];
            String word = elements[0];
            if (word2IdVocabulary.containsKey(word)) {
                for (int j = 0; j < vectorSize; j++) {
                    wordVectors[word2IdVocabulary.get(word)][j] = Double.parseDouble(elements[j + 1]);
                }
            }
            for (String line; (line = br.readLine()) != null;) {
                elements = line.trim().split("\\s+");
                word = elements[0];
                if (word2IdVocabulary.containsKey(word)) {
                    for (int j = 0; j < vectorSize; j++) {
                        wordVectors[word2IdVocabulary.get(word)][j] = Double.parseDouble(elements[j + 1]);
                    }
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

        for (int i = 0; i < vocabularySize; i++) {
            if (MatrixOps.absNorm(wordVectors[i]) == 0.0) {
                System.out.println("The word \"" + id2WordVocabulary.get(i)
                        + "\" doesn't have a corresponding vector!!!");
                throw new Exception();
            }
        }
    }

    public void initialize() throws IOException {
        System.out.println("Randomly initializing topic assignments ...");
        topicAssignments = new ArrayList<List<Integer>>();

        for (int docId = 0; docId < numDocuments; docId++) {
            List<Integer> topics = new ArrayList<Integer>();
            int docSize = corpus.get(docId).size();
            for (int j = 0; j < docSize; j++) {
                int wordId = corpus.get(docId).get(j);

                int subtopic = FuncUtils.nextDiscrete(multiPros);
                int topic = subtopic % numTopics;
                if (topic == subtopic) { // Generated from the latent feature component
                    topicWordCountLF[topic][wordId] += 1;
                    sumTopicWordCountLF[topic] += 1;
                } else { // Generated from the Dirichlet multinomial component
                    topicWordCountLDA[topic][wordId] += 1;
                    sumTopicWordCountLDA[topic] += 1;
                }
                docTopicCount[docId][topic] += 1;
                sumDocTopicCount[docId] += 1;

                topics.add(subtopic);
            }
            topicAssignments.add(topics);
        }
    }

    public void inference() throws IOException {
        System.out.println("Running Gibbs sampling inference: ");

        for (int iter = 1; iter <= numInitIterations; iter++) {
            System.out.println("\tInitial sampling iteration: " + (iter));
            sampleSingleInitialIteration();
        }

        for (int iter = 1; iter <= numIterations; iter++) {
            System.out.println("\tLFLDA sampling iteration: " + (iter));

            optimizeTopicVectors();
            sampleSingleIteration();

            if ((savestep > 0) && (iter % savestep == 0) && (iter < numIterations)) {
                System.out.println("\t\tSaving the output from the " + iter + "^{th} sample");
                expName = orgExpName + "-" + iter;
                write();
            }
        }
        expName = orgExpName;

        writeParameters();
        System.out.println("Writing output from the last sample ...");
        write();

        System.out.println("Sampling completed!");
    }

    public void optimizeTopicVectors() {
        System.out.println("\t\tEstimating topic vectors ...");
        sumExpValues = new double[numTopics];
        dotProductValues = new double[numTopics][vocabularySize];
        expDotProductValues = new double[numTopics][vocabularySize];

        Parallel.loop(numTopics, new Parallel.LoopInt() {
            @Override
            public void compute(int topic) {
                int rate = 1;
                boolean check = true;
                while (check) {
                    double l2Value = l2Regularizer * rate;
                    try {
                        TopicVectorOptimizer optimizer = new TopicVectorOptimizer(
                                topicVectors[topic], topicWordCountLF[topic], wordVectors, l2Value);

                        Optimizer gd = new LBFGS(optimizer, tolerance);
                        gd.optimize(600);
                        optimizer.getParameters(topicVectors[topic]);
                        sumExpValues[topic] = optimizer.computePartitionFunction(
                                dotProductValues[topic], expDotProductValues[topic]);
                        check = false;

                        if (sumExpValues[topic] == 0 || Double.isInfinite(sumExpValues[topic])) {
                            // The partition function under- or overflowed: rescale by the
                            // maximum dot product and recompute the sum from zero (adding to
                            // an infinite sum would leave it infinite).
                            double max = -1000000000.0;
                            for (int index = 0; index < vocabularySize; index++) {
                                if (dotProductValues[topic][index] > max)
                                    max = dotProductValues[topic][index];
                            }
                            sumExpValues[topic] = 0.0;
                            for (int index = 0; index < vocabularySize; index++) {
                                expDotProductValues[topic][index] = Math.exp(dotProductValues[topic][index] - max);
                                sumExpValues[topic] +=
                                        expDotProductValues[topic][index];
                            }
                        }
                    } catch (InvalidOptimizableException e) {
                        e.printStackTrace();
                        check = true;
                    }
                    rate = rate * 10;
                }
            }
        });
    }

    public void sampleSingleIteration() {
        for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
            int docSize = corpus.get(dIndex).size();
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                // Get current word
                int word = corpus.get(dIndex).get(wIndex); // wordID
                int subtopic = topicAssignments.get(dIndex).get(wIndex);
                int topic = subtopic % numTopics;

                docTopicCount[dIndex][topic] -= 1;
                if (subtopic == topic) {
                    topicWordCountLF[topic][word] -= 1;
                    sumTopicWordCountLF[topic] -= 1;
                } else {
                    topicWordCountLDA[topic][word] -= 1;
                    sumTopicWordCountLDA[topic] -= 1;
                }

                // Sample a pair of topic z and binary indicator variable s
                for (int tIndex = 0; tIndex < numTopics; tIndex++) {
                    multiPros[tIndex] = (docTopicCount[dIndex][tIndex] + alpha) * lambda
                            * expDotProductValues[tIndex][word] / sumExpValues[tIndex];
                    multiPros[tIndex + numTopics] = (docTopicCount[dIndex][tIndex] + alpha)
                            * (1 - lambda) * (topicWordCountLDA[tIndex][word] + beta)
                            / (sumTopicWordCountLDA[tIndex] + betaSum);
                }
                subtopic = FuncUtils.nextDiscrete(multiPros);
                topic = subtopic % numTopics;

                docTopicCount[dIndex][topic] += 1;
                if (subtopic == topic) {
                    topicWordCountLF[topic][word] += 1;
                    sumTopicWordCountLF[topic] += 1;
                } else {
                    topicWordCountLDA[topic][word] += 1;
                    sumTopicWordCountLDA[topic] += 1;
                }
                // Update topic assignments
                topicAssignments.get(dIndex).set(wIndex, subtopic);
            }
        }
    }

    public void sampleSingleInitialIteration() {
        for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
            int docSize = corpus.get(dIndex).size();
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                int word = corpus.get(dIndex).get(wIndex); // wordID
                int subtopic = topicAssignments.get(dIndex).get(wIndex);
                int topic = subtopic % numTopics;

                docTopicCount[dIndex][topic] -= 1;
                if (subtopic == topic) { // LF(w|t) + LDA(t|d)
                    topicWordCountLF[topic][word] -= 1;
                    sumTopicWordCountLF[topic] -= 1;
                } else { // LDA(w|t) + LDA(t|d)
                    topicWordCountLDA[topic][word] -= 1;
                    sumTopicWordCountLDA[topic] -= 1;
                }

                // Sample a pair of topic z and binary indicator variable s
                for (int tIndex = 0; tIndex < numTopics; tIndex++) {
                    multiPros[tIndex] = (docTopicCount[dIndex][tIndex] + alpha) * lambda
                            * (topicWordCountLF[tIndex][word] + beta)
                            / (sumTopicWordCountLF[tIndex] + betaSum);
                    multiPros[tIndex + numTopics] = (docTopicCount[dIndex][tIndex] + alpha)
                            * (1 - lambda) * (topicWordCountLDA[tIndex][word] + beta)
                            / (sumTopicWordCountLDA[tIndex] + betaSum);
                }
                subtopic = FuncUtils.nextDiscrete(multiPros);
                topic = subtopic % numTopics;

                docTopicCount[dIndex][topic] += 1;
                if (topic == subtopic) {
                    topicWordCountLF[topic][word] += 1;
                    sumTopicWordCountLF[topic] += 1;
                } else {
                    topicWordCountLDA[topic][word] += 1;
                    sumTopicWordCountLDA[topic] += 1;
                }
                // Update topic assignments
                topicAssignments.get(dIndex).set(wIndex, subtopic);
            }
        }
    }

    public void writeParameters() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".paras"));
        writer.write("-model" + "\t" + "LFLDA");
        writer.write("\n-corpus" + "\t" + corpusPath);
        writer.write("\n-vectors" + "\t" + vectorFilePath);
        writer.write("\n-ntopics" + "\t" + numTopics);
        writer.write("\n-alpha" + "\t" + alpha);
        writer.write("\n-beta" + "\t" + beta);
        writer.write("\n-lambda" + "\t" + lambda);
        writer.write("\n-initers" + "\t" + numInitIterations);
        writer.write("\n-niters" + "\t" + numIterations);
        writer.write("\n-twords" + "\t" + topWords);
        writer.write("\n-name" + "\t" + expName);
        if (savestep > 0)
            writer.write("\n-sstep" + "\t" + savestep);

        writer.close();
    }

    public void writeDictionary() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
                + ".vocabulary"));
        for (String word : word2IdVocabulary.keySet()) {
            writer.write(word + " " + word2IdVocabulary.get(word) + "\n");
        }
        writer.close();
    }

    public void writeIDbasedCorpus() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
                + ".IDcorpus"));
        for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
            int docSize = corpus.get(dIndex).size();
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                writer.write(corpus.get(dIndex).get(wIndex) + " ");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public void writeTopicAssignments() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
                + ".topicAssignments"));
        for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
            int docSize = corpus.get(dIndex).size();
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                writer.write(topicAssignments.get(dIndex).get(wIndex) + " ");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public void writeTopicVectors() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
                + ".topicVectors"));
        for (int i = 0; i < numTopics; i++) {
            for (int j = 0; j < vectorSize; j++)
                writer.write(topicVectors[i][j] + " ");
            writer.write("\n");
        }
        writer.close();
    }

    public void writeTopTopicalWords() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
                + ".topWords"));

        for (int tIndex = 0; tIndex < numTopics; tIndex++) {
            writer.write("Topic" + tIndex + ":");

            Map<Integer, Double> topicWordProbs = new TreeMap<Integer, Double>();
            for (int wIndex = 0; wIndex < vocabularySize; wIndex++) {
                double pro = lambda * expDotProductValues[tIndex][wIndex] / sumExpValues[tIndex]
                        + (1 - lambda) * (topicWordCountLDA[tIndex][wIndex] + beta)
                        / (sumTopicWordCountLDA[tIndex] + betaSum);
                topicWordProbs.put(wIndex, pro);
            }
            topicWordProbs = FuncUtils.sortByValueDescending(topicWordProbs);

            Set<Integer> mostLikelyWords = topicWordProbs.keySet();
            int count = 0;
            for (Integer index : mostLikelyWords) {
                if (count < topWords) {
                    writer.write(" " + id2WordVocabulary.get(index));
                    count += 1;
                } else {
                    writer.write("\n\n");
                    break;
                }
            }
        }
        writer.close();
    }

    public void writeTopicWordPros() throws IOException {
        BufferedWriter writer = new
BufferedWriter(new FileWriter(folderPath + expName + ".phi")); for (int t = 0; t < numTopics; t++) { for (int w = 0; w < vocabularySize; w++) { double pro = lambda * expDotProductValues[t][w] / sumExpValues[t] + (1 - lambda) * (topicWordCountLDA[t][w] + beta) / (sumTopicWordCountLDA[t] + betaSum); writer.write(pro + " "); } writer.write("\n"); } writer.close(); } public void writeDocTopicPros() throws IOException { BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".theta")); for (int i = 0; i < numDocuments; i++) { for (int j = 0; j < numTopics; j++) { double pro = (docTopicCount[i][j] + alpha) / (sumDocTopicCount[i] + alphaSum); writer.write(pro + " "); } writer.write("\n"); } writer.close(); } public void write() throws IOException { writeTopTopicalWords(); writeDocTopicPros(); writeTopicAssignments(); writeTopicWordPros(); } public static void main(String args[]) throws Exception { LFLDA_Inf lflda = new LFLDA_Inf("test/testLFLDA.paras", "test/corpus_test.txt", 2000, 200, 20, "testLFLDAinf", 0); lflda.inference(); } } ================================================ FILE: src/models/TopicVectorOptimizer.java ================================================ package models; import cc.mallet.optimize.Optimizable; import cc.mallet.types.MatrixOps; /** * Implementation of the MAP estimation for learning topic vectors, as described * in section 3.5 in: * * Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. 2015. * Improving Topic Models with Latent Feature Word Representations. Transactions * of the Association for Computational Linguistics, vol. 3, pp. 299-313. 
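A minimal standalone sketch of the objective this class maximizes for a single topic, written to mirror getValue() and getValueGradient() below. With n_w = wordCount[w], N = sum_w n_w, word vectors omega_w and L2 constant mu, the objective is L(tau) = sum_w n_w (tau . omega_w) - N log sum_w' exp(tau . omega_w') - mu ||tau||^2, and its gradient is sum_w (n_w - N p(w | tau)) omega_w - 2 mu tau, where p(w | tau) is the softmax over the vocabulary. The class TopicVectorObjectiveSketch and its methods are hypothetical and not part of LFTM. Unlike computePartitionFunction(), the sketch subtracts the maximum dot product before exponentiating (a standard log-sum-exp stabilization swapped in here), and main() compares the analytic gradient with central finite differences.

```java
import java.util.Arrays;

// Hypothetical sketch: names are illustrative and not part of LFTM.
public class TopicVectorObjectiveSketch {

    public static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) {
            s += a[i] * b[i];
        }
        return s;
    }

    // log sum_w exp(tau . omega_w), shifted by the maximum dot product for numerical stability
    public static double logPartition(double[] tau, double[][] wordVectors) {
        double[] dots = new double[wordVectors.length];
        double max = Double.NEGATIVE_INFINITY;
        for (int w = 0; w < wordVectors.length; w++) {
            dots[w] = dot(tau, wordVectors[w]);
            max = Math.max(max, dots[w]);
        }
        double sum = 0.0;
        for (double d : dots) {
            sum += Math.exp(d - max);
        }
        return max + Math.log(sum);
    }

    // L(tau) = sum_w n_w (tau . omega_w) - N log Z(tau) - mu ||tau||^2
    public static double value(double[] tau, int[] counts, double[][] wordVectors, double mu) {
        double value = 0.0;
        int total = 0;
        for (int w = 0; w < counts.length; w++) {
            value += counts[w] * dot(tau, wordVectors[w]);
            total += counts[w];
        }
        return value - total * logPartition(tau, wordVectors) - mu * dot(tau, tau);
    }

    // dL/dtau = sum_w (n_w - N p(w | tau)) omega_w - 2 mu tau
    public static double[] gradient(double[] tau, int[] counts, double[][] wordVectors, double mu) {
        double logZ = logPartition(tau, wordVectors);
        int total = 0;
        for (int c : counts) {
            total += c;
        }
        double[] grad = new double[tau.length];
        for (int w = 0; w < wordVectors.length; w++) {
            double p = Math.exp(dot(tau, wordVectors[w]) - logZ); // softmax probability of word w
            for (int i = 0; i < tau.length; i++) {
                grad[i] += (counts[w] - total * p) * wordVectors[w][i];
            }
        }
        for (int i = 0; i < tau.length; i++) {
            grad[i] -= 2 * mu * tau[i];
        }
        return grad;
    }

    public static void main(String[] args) {
        double[][] wordVectors = {{0.5, -0.2}, {0.1, 0.3}, {-0.4, 0.6}}; // toy 2-d word vectors
        int[] counts = {3, 1, 0};
        double[] tau = {0.2, -0.1};
        double mu = 0.01;
        double h = 1e-6;
        double[] g = gradient(tau, counts, wordVectors, mu);
        for (int i = 0; i < tau.length; i++) {
            double[] plus = Arrays.copyOf(tau, tau.length);
            double[] minus = Arrays.copyOf(tau, tau.length);
            plus[i] += h;
            minus[i] -= h;
            double numeric = (value(plus, counts, wordVectors, mu)
                - value(minus, counts, wordVectors, mu)) / (2 * h);
            System.out.printf("dim %d: analytic=%.6f numeric=%.6f%n", i, g[i], numeric);
        }
    }
}
```

For each dimension the analytic and numeric columns should agree up to finite-difference error.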
 *
 * @author Dat Quoc Nguyen
 */
public class TopicVectorOptimizer implements Optimizable.ByGradientValue {
    // Number of times each word type is assigned to the topic
    int[] wordCount;
    int totalCount; // Total number of words assigned to the topic
    int vocaSize; // Size of the vocabulary
    // wordCount.length = wordVectors.length = vocaSize

    double[][] wordVectors; // Vector representations for words
    double[] topicVector; // Vector representation for a topic
    int vectorSize; // vectorSize = topicVector.length

    // For each i_{th} element of topic vector, compute:
    // sum_w wordCount[w] * wordVectors[w][i]
    double[] expectedCountValues;

    double l2Constant; // L2 regularizer for learning topic vectors
    double[] dotProductValues;
    double[] expDotProductValues;

    public TopicVectorOptimizer(double[] inTopicVector, int[] inWordCount,
            double[][] inWordVectors, double inL2Constant) {
        vocaSize = inWordCount.length;
        vectorSize = inWordVectors[0].length;
        l2Constant = inL2Constant;

        topicVector = new double[vectorSize];
        System.arraycopy(inTopicVector, 0, topicVector, 0, inTopicVector.length);

        wordCount = new int[vocaSize];
        System.arraycopy(inWordCount, 0, wordCount, 0, vocaSize);

        wordVectors = new double[vocaSize][vectorSize];
        for (int w = 0; w < vocaSize; w++)
            System.arraycopy(inWordVectors[w], 0, wordVectors[w], 0, vectorSize);

        totalCount = 0;
        for (int w = 0; w < vocaSize; w++) {
            totalCount += wordCount[w];
        }

        expectedCountValues = new double[vectorSize];
        for (int i = 0; i < vectorSize; i++) {
            for (int w = 0; w < vocaSize; w++) {
                expectedCountValues[i] += wordCount[w] * wordVectors[w][i];
            }
        }

        dotProductValues = new double[vocaSize];
        expDotProductValues = new double[vocaSize];
    }

    @Override
    public int getNumParameters() {
        return vectorSize;
    }

    @Override
    public void getParameters(double[] buffer) {
        for (int i = 0; i < vectorSize; i++)
            buffer[i] = topicVector[i];
    }

    @Override
    public double getParameter(int index) {
        return topicVector[index];
    }

    @Override
    public void setParameters(double[] params) {
        for (int i = 0; i < params.length; i++)
            topicVector[i] = params[i];
    }

    @Override
    public void setParameter(int index, double value) {
        topicVector[index] = value;
    }

    @Override
    public void getValueGradient(double[] buffer) {
        double partitionFuncValue = computePartitionFunction(dotProductValues, expDotProductValues);

        for (int i = 0; i < vectorSize; i++) {
            buffer[i] = 0.0;

            double expectationValue = 0.0;
            for (int w = 0; w < vocaSize; w++) {
                expectationValue += wordVectors[w][i] * expDotProductValues[w];
            }
            expectationValue = expectationValue / partitionFuncValue;

            buffer[i] = expectedCountValues[i] - totalCount * expectationValue
                - 2 * l2Constant * topicVector[i];
        }
    }

    @Override
    public double getValue() {
        double logPartitionFuncValue = Math.log(computePartitionFunction(
            dotProductValues, expDotProductValues));

        double value = 0.0;
        for (int w = 0; w < vocaSize; w++) {
            if (wordCount[w] == 0)
                continue;
            value += wordCount[w] * dotProductValues[w];
        }
        value = value - totalCount * logPartitionFuncValue - l2Constant
            * MatrixOps.twoNormSquared(topicVector);
        return value;
    }

    // Compute the partition function Z = sum_w exp(wordVectors[w] . topicVector),
    // storing the dot products in elements1 and their exponentials in elements2
    public double computePartitionFunction(double[] elements1, double[] elements2) {
        double value = 0.0;
        for (int w = 0; w < vocaSize; w++) {
            elements1[w] = MatrixOps.dotProduct(wordVectors[w], topicVector);
            elements2[w] = Math.exp(elements1[w]);
            value += elements2[w];
        }
        return value;
    }
}
================================================
FILE: src/utility/CmdArgs.java
================================================
package utility;

import org.kohsuke.args4j.Option;

public class CmdArgs {

    @Option(name = "-model", usage = "Specify model", required = true)
    public String model = "";

    @Option(name = "-corpus", usage = "Specify path to topic modeling corpus")
    public String corpus = "";

    @Option(name = "-vectors", usage = "Specify path to the file containing word vectors")
    public String vectors = "";

    @Option(name = "-ntopics", usage = "Specify number of topics")
    public int ntopics = 20;

    @Option(name = "-alpha", usage = "Specify alpha")
    public double alpha = 0.1;

    @Option(name = "-beta", usage = "Specify beta")
    public double beta = 0.01;

    @Option(name = "-lambda", usage = "Specify mixture weight lambda")
    public double lambda = 0.6;

    @Option(name = "-initers", usage = "Specify number of initial sampling iterations")
    public int initers = 2000;

    @Option(name = "-niters", usage = "Specify number of EM-style sampling iterations")
    public int niters = 200;

    @Option(name = "-twords", usage = "Specify number of top topical words")
    public int twords = 20;

    @Option(name = "-name", usage = "Specify a name for a topic modeling experiment")
    public String expModelName = "model";

    @Option(name = "-initFile")
    public String initTopicAssgns = "";

    @Option(name = "-sstep")
    public int savestep = 0;

    @Option(name = "-dir")
    public String dir = "";

    @Option(name = "-label")
    public String labelFile = "";

    @Option(name = "-prob")
    public String prob = "";

    @Option(name = "-paras", usage = "Specify path to hyper-parameter file")
    public String paras = "";
}
================================================
FILE: src/utility/FuncUtils.java
================================================
package utility;

import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

public class FuncUtils {
    public static <K, V extends Comparable<? super V>> Map<K, V> sortByValueDescending(Map<K, V> map) {
        List<Map.Entry<K, V>> list = new LinkedList<Map.Entry<K, V>>(map.entrySet());
        Collections.sort(list, new Comparator<Map.Entry<K, V>>() {
            @Override
            public int compare(Map.Entry<K, V> o1, Map.Entry<K, V> o2) {
                int compare = (o1.getValue()).compareTo(o2.getValue());
                return -compare;
            }
        });

        Map<K, V> result = new LinkedHashMap<K, V>();
        for (Map.Entry<K, V> entry : list) {
            result.put(entry.getKey(), entry.getValue());
        }
        return result;
    }

    public static <K, V extends Comparable<? super V>> Map<K, V> sortByValueAscending(Map<K, V> map) {
        List<Map.Entry<K, V>> list = new LinkedList<Map.Entry<K, V>>(map.entrySet());
        Collections.sort(list, new Comparator<Map.Entry<K, V>>() {
            @Override
            public int compare(Map.Entry<K, V> o1, Map.Entry<K, V> o2) {
                int compare = (o1.getValue()).compareTo(o2.getValue());
                return compare;
            }
        });

        Map<K, V> result = new LinkedHashMap<K, V>();
        for (Map.Entry<K, V> entry : list) {
            result.put(entry.getKey(), entry.getValue());
        }
        return result;
    }

    /**
     * Samples an index from a discrete distribution given by non-negative,
     * unnormalized weights.
     *
     * @param probs
     *            unnormalized weights (they need not sum to 1)
     * @return the sampled index i, chosen with probability probs[i] / sum(probs)
     */
    public static int nextDiscrete(double[] probs) {
        double sum = 0.0;
        for (int i = 0; i < probs.length; i++)
            sum += probs[i];

        double r = MTRandom.nextDouble() * sum;

        sum = 0.0;
        for (int i = 0; i < probs.length; i++) {
            sum += probs[i];
            if (sum > r)
                return i;
        }
        return probs.length - 1;
    }

    public static double mean(double[] m) {
        double sum = 0;
        for (int i = 0; i < m.length; i++)
            sum += m[i];
        return sum / m.length;
    }

    public static double stddev(double[] m) {
        double mean = mean(m);
        double s = 0;
        for (int i = 0; i < m.length; i++)
            s += (m[i] - mean) * (m[i] - mean);
        return Math.sqrt(s / m.length);
    }
}
================================================
FILE: src/utility/LBFGS.java
================================================
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org. For further
   information, see the file `LICENSE' included with this distribution.
*/

/**
 * @author Aron Culotta culotta@cs.umass.edu
 */

/**
 * Limited Memory BFGS, as described in Byrd, Nocedal, and Schnabel,
 * "Representations of Quasi-Newton Matrices and Their Use in Limited Memory Methods"
 */
package utility;

import java.util.LinkedList;
import java.util.logging.Logger;

import cc.mallet.optimize.BackTrackLineSearch;
import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.LineOptimizer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.optimize.OptimizerEvaluator;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;

public class LBFGS implements Optimizer {

    private static Logger logger = MalletLogger
        .getLogger("edu.umass.cs.mallet.base.ml.maximize.LimitedMemoryBFGS");

    boolean converged = false;
    Optimizable.ByGradientValue optimizable;
    final int maxIterations = 5000;
    // xxx need a more principled stopping point
    // final double tolerance = .0001;
    // private double tolerance = 1.0e-10;
    // final double gradientTolerance = 1.0e-10;
    double tolerance; // = 1.0e-3;
    final double gradientTolerance = 1.0e-5;
    final double eps = 1.0e-10;

    // The number of corrections used in BFGS update
    // ideally 3 <= m <= 7. Larger m means more cpu time, memory.
    final int m = 10;

    // Line search function
    private LineOptimizer.ByGradient lineMaximizer;

    public LBFGS(Optimizable.ByGradientValue function, double inTolerance) {
        tolerance = inTolerance;
        this.optimizable = function;
        lineMaximizer = new BackTrackLineSearch(function);
    }

    @Override
    public Optimizable getOptimizable() {
        return this.optimizable;
    }

    @Override
    public boolean isConverged() {
        return converged;
    }

    /**
     * Sets the LineOptimizer.ByGradient to use in L-BFGS optimization.
     *
     * @param lineOpt
     *            line optimizer for L-BFGS
     */
    public void setLineOptimizer(LineOptimizer.ByGradient lineOpt) {
        lineMaximizer = lineOpt;
    }

    // State of search
    // g = gradient
    // s = list of the m most recent parameter differences (new parameters - old parameters)
    // y = list of the m most recent gradient differences (new g - old g)
    // rho = 1 / (s . y) for each stored pair
    double[] g, oldg, direction, parameters, oldParameters;
    LinkedList s = new LinkedList();
    LinkedList y = new LinkedList();
    LinkedList rho = new LinkedList();
    double[] alpha;
    static double step = 1.0; // static: shared by all LBFGS instances
    int iterations;

    private OptimizerEvaluator.ByGradient eval = null;

    // CPAL - added this
    public void setTolerance(double newtol) {
        this.tolerance = newtol;
    }

    public void setEvaluator(OptimizerEvaluator.ByGradient eval) {
        this.eval = eval;
    }

    public int getIteration() {
        return iterations;
    }

    @Override
    public boolean optimize() {
        return optimize(Integer.MAX_VALUE);
    }

    @Override
    public boolean optimize(int numIterations) {

        double initialValue = optimizable.getValue();
        logger.fine("Entering L-BFGS.optimize(). Initial Value=" + initialValue);

        if (g == null) { // first time through
            logger.fine("First time through L-BFGS");
            iterations = 0;
            s = new LinkedList();
            y = new LinkedList();
            rho = new LinkedList();
            alpha = new double[m];
            for (int i = 0; i < m; i++)
                alpha[i] = 0.0;

            parameters = new double[optimizable.getNumParameters()];
            oldParameters = new double[optimizable.getNumParameters()];
            g = new double[optimizable.getNumParameters()];
            oldg = new double[optimizable.getNumParameters()];
            direction = new double[optimizable.getNumParameters()];

            optimizable.getParameters(parameters);
            System.arraycopy(parameters, 0, oldParameters, 0, parameters.length);

            optimizable.getValueGradient(g);
            System.arraycopy(g, 0, oldg, 0, g.length);
            System.arraycopy(g, 0, direction, 0, g.length);

            if (MatrixOps.absNormalize(direction) == 0) {
                logger.info("L-BFGS initial gradient is zero; saying converged");
                g = null;
                converged = true;
                return true;
            }
            logger.fine("direction.2norm: " + MatrixOps.twoNorm(direction));
            MatrixOps.timesEquals(direction, 1.0 / MatrixOps.twoNorm(direction));

            // make initial jump
            logger.fine("before initial jump: \ndirection.2norm: "
                + MatrixOps.twoNorm(direction) + " \ngradient.2norm: "
                + MatrixOps.twoNorm(g) + "\nparameters.2norm: "
                + MatrixOps.twoNorm(parameters));

            // TestMaximizable.testValueAndGradientInDirection (maxable,
            // direction);
            step = lineMaximizer.optimize(direction, step);
            if (step == 0.0) { // could not step in this direction.
                //
                // give up and say converged.
                // g = null; // reset search
                // step = 1.0;
                // throw new OptimizationException(
                // "Line search could not step in the current direction. "
                // +
                // "(This is not necessarily cause for alarm. Sometimes this happens close to the maximum,"
                // + " where the function may be very flat.)");
                return false;
            }
            optimizable.getParameters(parameters);
            optimizable.getValueGradient(g);
            logger.fine("after initial jump: \ndirection.2norm: "
                + MatrixOps.twoNorm(direction) + " \ngradient.2norm: "
                + MatrixOps.twoNorm(g));
        }

        double value = optimizable.getValue();
        for (int iterationCount = 0; iterationCount < numIterations; iterationCount++) {
            logger.fine("L-BFGS iteration=" + iterationCount + ", value=" + value
                + " g.twoNorm: " + MatrixOps.twoNorm(g) + " oldg.twoNorm: "
                + MatrixOps.twoNorm(oldg));

            // if (iterationCount % 10 == 0)
            // System.out.println("\t\tL-BFGS iteration=" + iterationCount
            // + ", value=" + value + " g.twoNorm: "
            // + MatrixOps.twoNorm(g) + " oldg.twoNorm: "
            // + MatrixOps.twoNorm(oldg));

            // get difference between previous 2 gradients and parameters
            double sy = 0.0;
            double yy = 0.0;
            for (int i = 0; i < oldParameters.length; i++) {
                // -inf - (-inf) = 0; inf - inf = 0
                if (Double.isInfinite(parameters[i])
                    && Double.isInfinite(oldParameters[i])
                    && (parameters[i] * oldParameters[i] > 0))
                    oldParameters[i] = 0.0;
                else
                    oldParameters[i] = parameters[i] - oldParameters[i];

                if (Double.isInfinite(g[i]) && Double.isInfinite(oldg[i])
                    && (g[i] * oldg[i] > 0))
                    oldg[i] = 0.0;
                else
                    oldg[i] = g[i] - oldg[i];

                sy += oldParameters[i] * oldg[i]; // si * yi
                yy += oldg[i] * oldg[i];
                direction[i] = g[i];
            }

            if (sy > 0) {
                throw new InvalidOptimizableException("sy = " + sy + " > 0");
            }

            double gamma = sy / yy; // scaling factor
            if (gamma > 0)
                throw new InvalidOptimizableException("gamma = " + gamma + " > 0");

            push(rho, 1.0 / sy);
            push(s, oldParameters);
            push(y, oldg);

            // calculate new direction
            assert (s.size() == y.size()) : "s.size: " + s.size() + " y.size: " + y.size();
            for (int i = s.size() - 1; i >= 0; i--) {
                alpha[i] = ((Double) rho.get(i)).doubleValue()
                    * MatrixOps.dotProduct((double[]) s.get(i), direction);
                MatrixOps.plusEquals(direction, (double[]) y.get(i), -1.0 * alpha[i]);
            }
            MatrixOps.timesEquals(direction, gamma);
            for (int i = 0; i < y.size(); i++) {
                double beta = (((Double) rho.get(i)).doubleValue())
                    * MatrixOps.dotProduct((double[]) y.get(i), direction);
                MatrixOps.plusEquals(direction, (double[]) s.get(i), alpha[i] - beta);
            }

            for (int i = 0; i < oldg.length; i++) {
                oldParameters[i] = parameters[i];
                oldg[i] = g[i];
                direction[i] *= -1.0;
            }

            logger.fine("before linesearch: direction.gradient.dotprod: "
                + MatrixOps.dotProduct(direction, g) + "\ndirection.2norm: "
                + MatrixOps.twoNorm(direction) + "\nparameters.2norm: "
                + MatrixOps.twoNorm(parameters));
            // TestMaximizable.testValueAndGradientInDirection (maxable,
            // direction);
            step = lineMaximizer.optimize(direction, step);
            if (step == 0.0) { // could not step in this direction.
                g = null; // reset search
                step = 1.0;
                // xxx Temporary test; passed OK
                // TestMaximizable.testValueAndGradientInDirection (maxable,
                // direction);
                // System.out
                // .println("\t\tLine search could not step in the current direction.");
                // throw new OptimizationException(
                // "Line search could not step in the current direction. "
                // +
                // "(This is not necessarily cause for alarm. Sometimes this happens close to the maximum,"
                // + " where the function may be very flat.)");
                return false;
            }
            optimizable.getParameters(parameters);
            optimizable.getValueGradient(g);
            logger.fine("after linesearch: direction.2norm: "
                + MatrixOps.twoNorm(direction));
            double newValue = optimizable.getValue();

            // Test for terminations
            // if(2.0*Math.abs(newValue-value) <= tolerance*
            // (Math.abs(newValue)+Math.abs(value) + eps)){
            if (Math.abs(newValue - value) <= tolerance) {
                // System.out.println("\t\tNumber of iterations: "
                // + iterationCount);
                // System.out
                // .println("\t\tExiting L-BFGS on termination #1:\n\t\tvalue difference below "
                // + tolerance
                // + " (oldValue: "
                // + value
                // + " newValue: "
                // + newValue
                // + " gradient.twoNorm: "
                // + MatrixOps.twoNorm(g) + ")");
                converged = true;
                return true;
            }
            value = newValue;
            double gg = MatrixOps.twoNorm(g);
            if (gg < gradientTolerance) {
                logger.fine("Exiting L-BFGS on termination #2: \ngradient=" + gg
                    + " < " + gradientTolerance);
                converged = true;
                return true;
            }
            if (gg == 0.0) {
                logger.fine("Exiting L-BFGS on termination #3: \ngradient==0.0");
                converged = true;
                return true;
            }
            logger.fine("Gradient = " + gg);
            iterations++;
            if (iterations > maxIterations) {
                System.err
                    .println("Too many iterations in L-BFGS.java. Continuing with current parameters.");
                converged = true;
                return true;
                // throw new IllegalStateException ("Too many iterations.");
            }

            // end of iteration. call evaluator
            if (eval != null && !eval.evaluate(optimizable, iterationCount)) {
                logger.fine("Exiting L-BFGS on termination #4: evaluator returned false.");
                converged = true;
                return false;
            }
        }
        return false;
    }

    /**
     * Resets the previous gradients and values that are used to approximate the Hessian. NOTE - If
     * the {@link Optimizable} object is modified externally, this method should be called to avoid
     * IllegalStateExceptions.
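The history that reset() discards (the s, y and rho lists) is what the two-loop recursion in optimize() consumes. A minimal standalone sketch of that recursion follows, written in the usual minimization convention; this class maximizes instead, which is why it rejects pairs with s'y > 0 and negates the final direction. The class TwoLoopSketch and its method names are hypothetical, and the pairs are passed as plain arrays ordered oldest to newest rather than as LinkedLists. main() checks the secant condition: applied to the newest y, the recursion should return the newest s.

```java
// Hypothetical sketch of the L-BFGS two-loop recursion (minimization convention).
public class TwoLoopSketch {

    static double dot(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    // Returns H q, where H is the limited-memory inverse-Hessian approximation built from
    // the pairs s[i], y[i] (index 0 = oldest, last index = newest), starting from gamma I.
    public static double[] twoLoop(double[] q0, double[][] s, double[][] y) {
        int m = s.length;
        double[] q = q0.clone();
        double[] alpha = new double[m];
        double[] rho = new double[m];
        for (int i = m - 1; i >= 0; i--) { // newest to oldest
            rho[i] = 1.0 / dot(y[i], s[i]);
            alpha[i] = rho[i] * dot(s[i], q);
            for (int j = 0; j < q.length; j++) {
                q[j] -= alpha[i] * y[i][j];
            }
        }
        // Initial scaling gamma = s'y / y'y from the newest pair (1.0 with no history)
        double gamma = m > 0 ? dot(s[m - 1], y[m - 1]) / dot(y[m - 1], y[m - 1]) : 1.0;
        for (int j = 0; j < q.length; j++) {
            q[j] *= gamma;
        }
        for (int i = 0; i < m; i++) { // oldest to newest
            double beta = rho[i] * dot(y[i], q);
            for (int j = 0; j < q.length; j++) {
                q[j] += s[i][j] * (alpha[i] - beta);
            }
        }
        return q;
    }

    public static void main(String[] args) {
        // Pairs taken from the quadratic with Hessian diag(2, 5), so y = A s and s'y > 0
        double[][] s = {{1.0, 0.0}, {0.3, 0.4}};
        double[][] y = {{2.0, 0.0}, {0.6, 2.0}};
        double[] r = twoLoop(y[1], s, y);
        // Secant condition H y_newest = s_newest: prints values close to 0.3 and 0.4
        System.out.println(r[0] + " " + r[1]);
    }
}
```

With an empty history the recursion returns its input unchanged, which matches the unscaled first step.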
     */
    public void reset() {
        g = null;
    }

    /**
     * Pushes a new object onto the queue l
     *
     * @param l
     *            linked list queue of Matrix obj's
     * @param toadd
     *            matrix to push onto queue
     */
    private void push(LinkedList l, double[] toadd) {
        assert (l.size() <= m);
        if (l.size() == m) {
            // remove oldest matrix and add newest to end of list.
            // to make this more efficient, actually overwrite
            // memory of oldest matrix

            // this overwrites the oldest matrix
            double[] last = (double[]) l.get(0);
            System.arraycopy(toadd, 0, last, 0, toadd.length);
            Object ptr = last;
            // this readjusts the pointers in the list
            for (int i = 0; i < l.size() - 1; i++)
                l.set(i, l.get(i + 1));
            l.set(m - 1, ptr);
        } else {
            double[] newArray = new double[toadd.length];
            System.arraycopy(toadd, 0, newArray, 0, toadd.length);
            l.addLast(newArray);
        }
    }

    /**
     * Pushes a new object onto the queue l
     *
     * @param l
     *            linked list queue of Double obj's
     * @param toadd
     *            double value to push onto queue
     */
    private void push(LinkedList l, double toadd) {
        assert (l.size() <= m);
        if (l.size() == m) { // pop old double and add new
            l.removeFirst();
            l.addLast(new Double(toadd));
        } else
            l.addLast(new Double(toadd));
    }
}
================================================
FILE: src/utility/MTRandom.java
================================================
package utility;

public class MTRandom {

    private static MersenneTwister rand = new MersenneTwister();

    public static void setSeed(long seed) {
        rand.setSeed(seed);
    }

    public static double nextDouble() {
        return rand.nextDouble();
    }

    public static int nextInt(int n) {
        return rand.nextInt(n);
    }

    public static boolean nextBoolean() {
        return rand.nextBoolean();
    }
}
================================================
FILE: src/utility/MersenneTwister.java
================================================
package utility;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

/**
 * MersenneTwister and MersenneTwisterFast
 *
 * Version 20, based on version MT19937 (99/10/29) of the Mersenne Twister algorithm found at
 * The Mersenne Twister Home Page, with the initialization improved using the new 2002/1/26
 * initialization algorithm. By Sean Luke, October 2004.
 *
 * MersenneTwister is a drop-in subclass replacement for java.util.Random. It is properly
 * synchronized and can be used in a multithreaded environment. On modern VMs such as HotSpot, it is
 * approximately 1/3 slower than java.util.Random.
 *
 * MersenneTwisterFast is not a subclass of java.util.Random. It has the same public methods
 * as Random does, however, and it is algorithmically identical to MersenneTwister.
 * MersenneTwisterFast has hard-code inlined all of its methods directly, and made all of them final
 * (well, the ones of consequence anyway). Further, these methods are not synchronized, so
 * the same MersenneTwisterFast instance cannot be shared by multiple threads. But all this helps
 * MersenneTwisterFast achieve well over twice the speed of MersenneTwister. java.util.Random is
 * about 1/3 slower than MersenneTwisterFast.
 *
 * About the Mersenne Twister
 *
 * This is a Java version of the C-program for MT19937: Integer version. The MT19937 algorithm was
 * created by Makoto Matsumoto and Takuji Nishimura, who ask: "When you use this, send an email to:
 * matumoto@math.keio.ac.jp with an appropriate reference to your work". Indicate that this is a
 * translation of their algorithm into Java.
 *
 * Reference. Makoto Matsumoto and Takuji Nishimura, "Mersenne Twister: A 623-Dimensionally
 * Equidistributed Uniform Pseudo-Random Number Generator", ACM Transactions on Modeling and
 * Computer Simulation, Vol. 8, No. 1, January 1998, pp. 3-30.
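A compact, unsynchronized sketch of the MT19937 core described above: the init_genrand seeding recurrence (the same one setSeed(long) uses below), the twist step written with modular indexing instead of the three split loops in next(int), and the standard tempering shifts. The class Mt19937Sketch is hypothetical; it returns full 32-bit outputs as unsigned longs. Seeded with 5489, the default seed of the reference C code, its first output should be 3499211612, the same value a default-constructed std::mt19937 produces.

```java
// Hypothetical compact MT19937 sketch (not synchronized); not part of LFTM.
public class Mt19937Sketch {
    private static final int N = 624;
    private static final int M = 397;
    private final int[] mt = new int[N];
    private int mti;

    public Mt19937Sketch(int seed) {
        // init_genrand: the same recurrence as MersenneTwister.setSeed(long)
        mt[0] = seed;
        for (mti = 1; mti < N; mti++) {
            mt[mti] = 1812433253 * (mt[mti - 1] ^ (mt[mti - 1] >>> 30)) + mti;
        }
    }

    // Returns the next 32-bit output as an unsigned value in [0, 2^32)
    public long nextUnsigned() {
        if (mti >= N) {
            // Twist: regenerate all N state words; (kk + 1) % N and (kk + M) % N wrap around
            for (int kk = 0; kk < N; kk++) {
                int y = (mt[kk] & 0x80000000) | (mt[(kk + 1) % N] & 0x7fffffff);
                int mag = (y & 1) != 0 ? 0x9908b0df : 0;
                mt[kk] = mt[(kk + M) % N] ^ (y >>> 1) ^ mag;
            }
            mti = 0;
        }
        int y = mt[mti++];
        // Tempering
        y ^= y >>> 11;
        y ^= (y << 7) & 0x9d2c5680;
        y ^= (y << 15) & 0xefc60000;
        y ^= y >>> 18;
        return y & 0xffffffffL;
    }

    public static void main(String[] args) {
        Mt19937Sketch rng = new Mt19937Sketch(5489);
        System.out.println(rng.nextUnsigned()); // 3499211612
    }
}
```

The modular indexing trades a little speed for a shorter loop; it visits state words in the same order as the split loops, so the outputs are identical.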
 *
 * About this Version
 *
 * Changes since V19: nextFloat(boolean, boolean) now returns float, not double.
 *
 * Changes since V18: Removed old final declarations, which used to potentially speed up the
 * code, but no longer.
 *
 * Changes since V17: Removed vestigial references to &= 0xffffffff which stemmed from the
 * original C code. The C code could not guarantee that ints were 32 bit, hence the masks. The
 * vestigial references in the Java code were likely optimized out anyway.
 *
 * Changes since V16: Added nextDouble(includeZero, includeOne) and nextFloat(includeZero,
 * includeOne) to allow for half-open, fully-closed, and fully-open intervals.
 *
 * Changes since V15: Added serialVersionUID to quiet compiler warnings from Sun's overly
 * verbose compilers as of JDK 1.5.
 *
 * Changes since V14: made strictfp, with StrictMath.log and StrictMath.sqrt in nextGaussian
 * instead of Math.log and Math.sqrt. This is largely just to be safe, as it presently makes no
 * difference in the speed, correctness, or results of the algorithm.
 *
 * Changes since V13: clone() method CloneNotSupportedException removed.
 *
 * Changes since V12: clone() method added.
 *
 * Changes since V11: stateEquals(...) method added. MersenneTwisterFast is equal to other
 * MersenneTwisterFasts with identical state; likewise MersenneTwister is equal to other
 * MersenneTwister with identical state. This isn't equals(...) because that requires a contract of
 * immutability to compare by value.
 *
 * Changes since V10: A documentation error suggested that setSeed(int[]) required an int[]
 * array 624 long. In fact, the array can be any non-zero length. The new version also checks for
 * this fact.
 *
 * Changes since V9: readState(stream) and writeState(stream) provided.
 *
 * Changes since V8: setSeed(int) was only using the first 28 bits of the seed; it should
 * have been 32 bits. For small-number seeds the behavior is identical.
 *
 * Changes since V7: A documentation error in MersenneTwisterFast (but not MersenneTwister)
 * stated that nextDouble selects uniformly from the fully-closed interval [0,1]. It does not.
 * nextDouble's contract is identical across MersenneTwisterFast, MersenneTwister, and
 * java.util.Random, namely, selection in the half-open interval [0,1). That is, 1.0 should not be
 * returned. A similar contract exists in nextFloat.
 *
 * Changes since V6: License has changed from LGPL to BSD. New timing information to compare
 * against java.util.Random. Recent versions of HotSpot have helped Random increase in speed to the
 * point where it is faster than MersenneTwister but slower than MersenneTwisterFast (which should
 * be the case, as it's a less complex algorithm but is synchronized).
 *
 * Changes since V5: New empty constructor made to work the same as java.util.Random:
 * it seeds based on the current time in milliseconds.
 *
 * Changes since V4: New initialization algorithms. See
 * http://www.math.keio.ac.jp/matumoto/MT2002/emt19937ar.html
 *
 * The MersenneTwister code is based on standard MT19937 C/C++ code by Takuji Nishimura, with
 * suggestions from Topher Cooper and Marc Rieffel, July 1997. The code was originally translated
 * into Java by Michael Lecuyer, January 1999, and the original code is Copyright (c) 1999 by
 * Michael Lecuyer.
 *
 * Java notes
 *
 * This implementation implements the bug fixes made in Java 1.2's version of Random, which means it
 * can be used with earlier versions of Java. See the JDK 1.2 java.util.Random documentation for
 * further documentation on the random-number generation contracts made. Additionally, there's an
 * undocumented bug in the JDK java.util.Random.nextBytes() method, which this code fixes.
 *
 * Just like java.util.Random, this generator accepts a long seed but doesn't use all of it.
 * java.util.Random uses 48 bits. The Mersenne Twister instead uses 32 bits (int size). So it's best
 * if your seed does not exceed the int range.
 *
 * MersenneTwister can be used reliably on JDK version 1.1.5 or above. Earlier Java versions have
 * serious bugs in java.util.Random; only MersenneTwisterFast (and not MersenneTwister nor
 * java.util.Random) should be used with them.
 *
 * License
 *
 * Copyright (c) 2003 by Sean Luke.
 * Portions copyright (c) 1993 by Michael Lecuyer.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *
 * - Redistributions of source code must retain the above copyright notice, this list of
 *   conditions and the following disclaimer.
 * - Redistributions in binary form must reproduce the above copyright notice, this list of
 *   conditions and the following disclaimer in the documentation and/or other materials provided with
 *   the distribution.
 * - Neither the name of the copyright owners, their employers, nor the names of its contributors
 *   may be used to endorse or promote products derived from this software without specific prior
 *   written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNERS OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
 * WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * @version 20
 */
public strictfp class MersenneTwister extends java.util.Random implements Serializable, Cloneable {
    // Serialization
    private static final long serialVersionUID = -4035832775130174188L; // locked as of Version 15

    // Period parameters
    private static final int N = 624;
    private static final int M = 397;
    private static final int MATRIX_A = 0x9908b0df; // private static final * constant vector a
    private static final int UPPER_MASK = 0x80000000; // most significant w-r bits
    private static final int LOWER_MASK = 0x7fffffff; // least significant r bits

    // Tempering parameters
    private static final int TEMPERING_MASK_B = 0x9d2c5680;
    private static final int TEMPERING_MASK_C = 0xefc60000;

    private int mt[]; // the array for the state vector
    private int mti; // mti==N+1 means mt[N] is not initialized
    private int mag01[];

    // a good initial seed (of int size, though stored in a long)
    // private static final long GOOD_SEED = 4357;

    /*
     * implemented here because there's a bug in Random's implementation of the Gaussian code
     * (divide by zero, and log(0), ugh!), yet its gaussian variables are private so we can't access
     * them here. :-(
     */
    private double __nextNextGaussian;
    private boolean __haveNextNextGaussian;

    /*
     * We're overriding all internal data, to my knowledge, so this should be okay
     */
    public Object clone() {
        try {
            MersenneTwister f = (MersenneTwister) (super.clone());
            f.mt = (int[]) (mt.clone());
            f.mag01 = (int[]) (mag01.clone());
            return f;
        } catch (CloneNotSupportedException e) {
            throw new InternalError();
        } // should never happen
    }

    public boolean stateEquals(Object o) {
        if (o == this)
            return true;
        if (o == null || !(o instanceof MersenneTwister))
            return false;
        MersenneTwister other = (MersenneTwister) o;
        if (mti != other.mti)
            return false;
        for (int x = 0; x < mag01.length; x++)
            if (mag01[x] != other.mag01[x])
                return false;
        for (int x = 0; x < mt.length; x++)
            if (mt[x] != other.mt[x])
                return false;
        return true;
    }

    /** Reads the entire state of the MersenneTwister RNG from the stream */
    public void readState(DataInputStream stream) throws IOException {
        int len = mt.length;
        for (int x = 0; x < len; x++)
            mt[x] = stream.readInt();

        len = mag01.length;
        for (int x = 0; x < len; x++)
            mag01[x] = stream.readInt();

        mti = stream.readInt();
        __nextNextGaussian = stream.readDouble();
        __haveNextNextGaussian = stream.readBoolean();
    }

    /** Writes the entire state of the MersenneTwister RNG to the stream */
    public void writeState(DataOutputStream stream) throws IOException {
        int len = mt.length;
        for (int x = 0; x < len; x++)
            stream.writeInt(mt[x]);

        len = mag01.length;
        for (int x = 0; x < len; x++)
            stream.writeInt(mag01[x]);

        stream.writeInt(mti);
        stream.writeDouble(__nextNextGaussian);
        stream.writeBoolean(__haveNextNextGaussian);
    }

    /**
     * Constructor using the default seed.
     */
    public MersenneTwister() {
        this(System.currentTimeMillis());
    }

    /**
     * Constructor using a given seed. Though you pass this seed in as a long, it's best to make
     * sure it's actually an integer.
     */
    public MersenneTwister(long seed) {
        super(seed); /* just in case */
        setSeed(seed);
    }

    /**
     * Constructor using an array of integers as seed.
     * Your array must have a non-zero length. Only the first 624 integers in the array are used;
     * if the array is shorter than this then integers are repeatedly used in a wrap-around fashion.
     */
    public MersenneTwister(int[] array) {
        super(System.currentTimeMillis()); /* pick something at random just in case */
        setSeed(array);
    }

    /**
     * Initialize the pseudo random number generator. Don't pass in a long that's bigger than an int
     * (Mersenne Twister only uses the first 32 bits for its seed).
     */
    synchronized public void setSeed(long seed) {
        // it's always good style to call super
        super.setSeed(seed);

        // Due to a bug in java.util.Random clear up to 1.2, we're
        // doing our own Gaussian variable.
        __haveNextNextGaussian = false;

        mt = new int[N];

        mag01 = new int[2];
        mag01[0] = 0x0;
        mag01[1] = MATRIX_A;

        mt[0] = (int) (seed & 0xffffffff);
        mt[0] = (int) seed;
        for (mti = 1; mti < N; mti++) {
            mt[mti] = (1812433253 * (mt[mti - 1] ^ (mt[mti - 1] >>> 30)) + mti);
            /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */
            /* In the previous versions, MSBs of the seed affect */
            /* only MSBs of the array mt[]. */
            /* 2002/01/09 modified by Makoto Matsumoto */
            // mt[mti] &= 0xffffffff; /* for >32 bit machines */
        }
    }

    /**
     * Sets the seed of the MersenneTwister using an array of integers. Your array must have a
     * non-zero length. Only the first 624 integers in the array are used; if the array is shorter
     * than this then integers are repeatedly used in a wrap-around fashion.
     */
    synchronized public void setSeed(int[] array) {
        if (array.length == 0)
            throw new IllegalArgumentException("Array length must be greater than zero");
        int i, j, k;
        setSeed(19650218);
        i = 1;
        j = 0;
        k = (N > array.length ?
                N : array.length);
        for (; k != 0; k--) {
            mt[i] = (mt[i] ^ ((mt[i - 1] ^ (mt[i - 1] >>> 30)) * 1664525)) + array[j] + j; /* non linear */
            // mt[i] &= 0xffffffff; /* for WORDSIZE > 32 machines */
            i++;
            j++;
            if (i >= N) {
                mt[0] = mt[N - 1];
                i = 1;
            }
            if (j >= array.length)
                j = 0;
        }
        for (k = N - 1; k != 0; k--) {
            mt[i] = (mt[i] ^ ((mt[i - 1] ^ (mt[i - 1] >>> 30)) * 1566083941)) - i; /* non linear */
            // mt[i] &= 0xffffffff; /* for WORDSIZE > 32 machines */
            i++;
            if (i >= N) {
                mt[0] = mt[N - 1];
                i = 1;
            }
        }
        mt[0] = 0x80000000; /* MSB is 1; assuring non-zero initial array */
    }

    /**
     * Returns an integer with {@code bits} bits filled with a random number.
     */
    synchronized protected int next(int bits) {
        int y;

        if (mti >= N) { // generate N words at one time
            int kk;
            final int[] mt = this.mt; // locals are slightly faster
            final int[] mag01 = this.mag01; // locals are slightly faster

            for (kk = 0; kk < N - M; kk++) {
                y = (mt[kk] & UPPER_MASK) | (mt[kk + 1] & LOWER_MASK);
                mt[kk] = mt[kk + M] ^ (y >>> 1) ^ mag01[y & 0x1];
            }
            for (; kk < N - 1; kk++) {
                y = (mt[kk] & UPPER_MASK) | (mt[kk + 1] & LOWER_MASK);
                mt[kk] = mt[kk + (M - N)] ^ (y >>> 1) ^ mag01[y & 0x1];
            }
            y = (mt[N - 1] & UPPER_MASK) | (mt[0] & LOWER_MASK);
            mt[N - 1] = mt[M - 1] ^ (y >>> 1) ^ mag01[y & 0x1];

            mti = 0;
        }

        y = mt[mti++];
        y ^= y >>> 11; // TEMPERING_SHIFT_U(y)
        y ^= (y << 7) & TEMPERING_MASK_B; // TEMPERING_SHIFT_S(y)
        y ^= (y << 15) & TEMPERING_MASK_C; // TEMPERING_SHIFT_T(y)
        y ^= (y >>> 18); // TEMPERING_SHIFT_L(y)

        return y >>> (32 - bits); // hope that's right!
    }

    /*
     * If you've got a truly old version of Java, you can omit these two next methods.
     */

    private synchronized void writeObject(ObjectOutputStream out) throws IOException {
        // just so we're synchronized.
        out.defaultWriteObject();
    }

    private synchronized void readObject(ObjectInputStream in) throws IOException,
            ClassNotFoundException {
        // just so we're synchronized.
        in.defaultReadObject();
    }

    /**
     * This method is missing from jdk 1.0.x and below.
     * JDK 1.1 includes this for us, but what the heck.
     */
    public boolean nextBoolean() {
        return next(1) != 0;
    }

    /**
     * This generates a coin flip with a probability {@code probability} of returning true, else
     * returning false. {@code probability} must be between 0.0 and 1.0, inclusive. Not as precise
     * a random real event as nextBoolean(double), but twice as fast. To explicitly use this,
     * remember you may need to cast to float first.
     */
    public boolean nextBoolean(float probability) {
        if (probability < 0.0f || probability > 1.0f)
            throw new IllegalArgumentException("probability must be between 0.0 and 1.0 inclusive.");
        if (probability == 0.0f)
            return false; // fix half-open issues
        else if (probability == 1.0f)
            return true; // fix half-open issues
        return nextFloat() < probability;
    }

    /**
     * This generates a coin flip with a probability {@code probability} of returning true, else
     * returning false. {@code probability} must be between 0.0 and 1.0, inclusive.
     */
    public boolean nextBoolean(double probability) {
        if (probability < 0.0 || probability > 1.0)
            throw new IllegalArgumentException("probability must be between 0.0 and 1.0 inclusive.");
        if (probability == 0.0)
            return false; // fix half-open issues
        else if (probability == 1.0)
            return true; // fix half-open issues
        return nextDouble() < probability;
    }

    /**
     * This method is missing from JDK 1.1 and below. JDK 1.2 includes this for us, but what the
     * heck.
     */
    public int nextInt(int n) {
        if (n <= 0)
            throw new IllegalArgumentException("n must be positive, got: " + n);

        if ((n & -n) == n) // n is a power of 2: use the high-order bits directly
            return (int) ((n * (long) next(31)) >> 31);

        int bits, val;
        do {
            bits = next(31);
            val = bits % n;
        } while (bits - val + (n - 1) < 0); // reject draws from the final partial block (avoids modulo bias)
        return val;
    }

    /**
     * This method is for completeness' sake. Returns a long drawn uniformly from 0 to n-1. Suffice
     * it to say, n must be > 0, or an IllegalArgumentException is raised.
*/ public long nextLong(long n) { if (n <= 0) throw new IllegalArgumentException("n must be positive, got: " + n); long bits, val; do { bits = (nextLong() >>> 1); val = bits % n; } while (bits - val + (n - 1) < 0); return val; } /** * A bug fix for versions of JDK 1.1 and below. JDK 1.2 fixes this for us, but what the heck. */ public double nextDouble() { return (((long) next(26) << 27) + next(27)) / (double) (1L << 53); } /** * Returns a double in the range from 0.0 to 1.0, possibly inclusive of 0.0 and 1.0 themselves. * Thus: * *
     *
     *   Expression                  Interval
     *   nextDouble(false, false)    (0.0, 1.0)
     *   nextDouble(true, false)     [0.0, 1.0)
     *   nextDouble(false, true)     (0.0, 1.0]
     *   nextDouble(true, true)      [0.0, 1.0]
     *
     * This version preserves all possible random values in the double range.
     */
    public double nextDouble(boolean includeZero, boolean includeOne) {
        double d = 0.0;
        do {
            d = nextDouble(); // grab a value, initially from half-open [0.0, 1.0)
            if (includeOne && nextBoolean())
                d += 1.0; // if includeOne, with 1/2 probability, push to [1.0, 2.0)
        } while ((d > 1.0) || // everything above 1.0 is always invalid
                (!includeZero && d == 0.0)); // if we're not including zero, 0.0 is invalid
        return d;
    }

    /**
     * A bug fix for versions of JDK 1.1 and below. JDK 1.2 fixes this for us, but what the heck.
     */
    public float nextFloat() {
        return next(24) / ((float) (1 << 24));
    }

    /**
     * Returns a float in the range from 0.0f to 1.0f, possibly inclusive of 0.0f and 1.0f
     * themselves. Thus:
     *
     *
     *   Expression                  Interval
     *   nextFloat(false, false)     (0.0f, 1.0f)
     *   nextFloat(true, false)      [0.0f, 1.0f)
     *   nextFloat(false, true)      (0.0f, 1.0f]
     *   nextFloat(true, true)       [0.0f, 1.0f]
     *
     * This version preserves all possible random values in the float range.
     */
    public float nextFloat(boolean includeZero, boolean includeOne) {
        float d = 0.0f;
        do {
            d = nextFloat(); // grab a value, initially from half-open [0.0f, 1.0f)
            if (includeOne && nextBoolean())
                d += 1.0f; // if includeOne, with 1/2 probability, push to [1.0f, 2.0f)
        } while ((d > 1.0f) || // everything above 1.0f is always invalid
                (!includeZero && d == 0.0f)); // if we're not including zero, 0.0f is invalid
        return d;
    }

    /**
     * A bug fix for all versions of the JDK. The JDK appears to use all four bytes in an integer as
     * independent byte values! Totally wrong. I've submitted a bug report.
     */
    public void nextBytes(byte[] bytes) {
        for (int x = 0; x < bytes.length; x++)
            bytes[x] = (byte) next(8);
    }

    /** For completeness' sake, though it's not in java.util.Random. */
    public char nextChar() {
        // chars are 16-bit Unicode (UTF-16) code units
        return (char) (next(16));
    }

    /** For completeness' sake, though it's not in java.util.Random. */
    public short nextShort() {
        return (short) (next(16));
    }

    /** For completeness' sake, though it's not in java.util.Random. */
    public byte nextByte() {
        return (byte) (next(8));
    }

    /**
     * A bug fix for all JDK code including 1.2. nextGaussian can theoretically ask for the log of 0
     * and divide it by 0! See Java bug
     * http://developer.java.sun.com/developer/bugParade/bugs/4254501.html
     */
    synchronized public double nextGaussian() {
        if (__haveNextNextGaussian) {
            __haveNextNextGaussian = false;
            return __nextNextGaussian;
        } else {
            // Marsaglia polar method: sample (v1, v2) uniformly in the unit disc, excluding the origin
            double v1, v2, s;
            do {
                v1 = 2 * nextDouble() - 1; // between -1.0 and 1.0
                v2 = 2 * nextDouble() - 1; // between -1.0 and 1.0
                s = v1 * v1 + v2 * v2;
            } while (s >= 1 || s == 0);
            double multiplier = StrictMath.sqrt(-2 * StrictMath.log(s) / s);
            __nextNextGaussian = v2 * multiplier; // cache the second deviate for the next call
            __haveNextNextGaussian = true;
            return v1 * multiplier;
        }
    }

    /**
     * Tests the code.
     */
    public static void main(String args[]) {
        int j;

        MersenneTwister r;

        // CORRECTNESS TEST
        // COMPARE WITH http://www.math.keio.ac.jp/matumoto/CODES/MT2002/mt19937ar.out

        r = new MersenneTwister(new int[] { 0x123, 0x234, 0x345, 0x456 });
        System.out.println("Output of MersenneTwister with new (2002/1/26) seeding mechanism");
        for (j = 0; j < 1000; j++) {
            // first, convert the int from signed to "unsigned"
            long l = (long) r.nextInt();
            if (l < 0)
                l += 4294967296L; // add 2^32 to map a negative int to its unsigned value
            String s = String.valueOf(l);
            while (s.length() < 10)
                s = " " + s; // left-pad to a width of 10 characters
            System.out.print(s + " ");
            if (j % 5 == 4)
                System.out.println();
        }

        // SPEED TEST

        final long SEED = 4357;

        int xx;
        long ms;
        System.out.println("\nTime to test grabbing 100000000 ints");

        r = new MersenneTwister(SEED);
        ms = System.currentTimeMillis();
        xx = 0;
        for (j = 0; j < 100000000; j++)
            xx += r.nextInt();
        System.out.println("Mersenne Twister: " + (System.currentTimeMillis() - ms) + " Ignore this: " + xx);

        System.out.println("To compare this with java.util.Random, run this same test on MersenneTwisterFast.");
        System.out.println("The comparison with Random is removed from MersenneTwister because it is a proper");
        System.out.println("subclass of Random and this unfairly makes some of Random's methods un-inlinable,");
        System.out.println("so it would make Random look worse than it is.");

        // TEST TO COMPARE TYPE CONVERSION BETWEEN
        // MersenneTwisterFast.java AND MersenneTwister.java

        System.out.println("\nGrab the first 1000 booleans");
        r = new MersenneTwister(SEED);
        for (j = 0; j < 1000; j++) {
            System.out.print(r.nextBoolean() + " ");
            if (j % 8 == 7)
                System.out.println();
        }
        if (!(j % 8 == 7))
            System.out.println();

        System.out.println("\nGrab 1000 booleans of increasing probability using nextBoolean(double)");
        r = new MersenneTwister(SEED);
        for (j = 0; j < 1000; j++) {
            System.out.print(r.nextBoolean((double) (j / 999.0)) + " ");
            if (j % 8 == 7)
                System.out.println();
        }
        if (!(j % 8 == 7))
            System.out.println();

        System.out
.println("\nGrab 1000 booleans of increasing probability using nextBoolean(float)"); r = new MersenneTwister(SEED); for (j = 0; j < 1000; j++) { System.out.print(r.nextBoolean((float) (j / 999.0f)) + " "); if (j % 8 == 7) System.out.println(); } if (!(j % 8 == 7)) System.out.println(); byte[] bytes = new byte[1000]; System.out.println("\nGrab the first 1000 bytes using nextBytes"); r = new MersenneTwister(SEED); r.nextBytes(bytes); for (j = 0; j < 1000; j++) { System.out.print(bytes[j] + " "); if (j % 16 == 15) System.out.println(); } if (!(j % 16 == 15)) System.out.println(); byte b; System.out.println("\nGrab the first 1000 bytes -- must be same as nextBytes"); r = new MersenneTwister(SEED); for (j = 0; j < 1000; j++) { System.out.print((b = r.nextByte()) + " "); if (b != bytes[j]) System.out.print("BAD "); if (j % 16 == 15) System.out.println(); } if (!(j % 16 == 15)) System.out.println(); System.out.println("\nGrab the first 1000 shorts"); r = new MersenneTwister(SEED); for (j = 0; j < 1000; j++) { System.out.print(r.nextShort() + " "); if (j % 8 == 7) System.out.println(); } if (!(j % 8 == 7)) System.out.println(); System.out.println("\nGrab the first 1000 ints"); r = new MersenneTwister(SEED); for (j = 0; j < 1000; j++) { System.out.print(r.nextInt() + " "); if (j % 4 == 3) System.out.println(); } if (!(j % 4 == 3)) System.out.println(); System.out.println("\nGrab the first 1000 ints of different sizes"); r = new MersenneTwister(SEED); int max = 1; for (j = 0; j < 1000; j++) { System.out.print(r.nextInt(max) + " "); max *= 2; if (max <= 0) max = 1; if (j % 4 == 3) System.out.println(); } if (!(j % 4 == 3)) System.out.println(); System.out.println("\nGrab the first 1000 longs"); r = new MersenneTwister(SEED); for (j = 0; j < 1000; j++) { System.out.print(r.nextLong() + " "); if (j % 3 == 2) System.out.println(); } if (!(j % 3 == 2)) System.out.println(); System.out.println("\nGrab the first 1000 longs of different sizes"); r = new MersenneTwister(SEED); long 
max2 = 1; for (j = 0; j < 1000; j++) { System.out.print(r.nextLong(max2) + " "); max2 *= 2; if (max2 <= 0) max2 = 1; if (j % 4 == 3) System.out.println(); } if (!(j % 4 == 3)) System.out.println(); System.out.println("\nGrab the first 1000 floats"); r = new MersenneTwister(SEED); for (j = 0; j < 1000; j++) { System.out.print(r.nextFloat() + " "); if (j % 4 == 3) System.out.println(); } if (!(j % 4 == 3)) System.out.println(); System.out.println("\nGrab the first 1000 doubles"); r = new MersenneTwister(SEED); for (j = 0; j < 1000; j++) { System.out.print(r.nextDouble() + " "); if (j % 3 == 2) System.out.println(); } if (!(j % 3 == 2)) System.out.println(); System.out.println("\nGrab the first 1000 gaussian doubles"); r = new MersenneTwister(SEED); for (j = 0; j < 1000; j++) { System.out.print(r.nextGaussian() + " "); if (j % 3 == 2) System.out.println(); } if (!(j % 3 == 2)) System.out.println(); } } ================================================ FILE: src/utility/Parallel.java ================================================ package utility; import java.util.Collection; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.RecursiveAction; import java.util.concurrent.RecursiveTask; /** * Utilities for parallel computing in loops over independent tasks. This class * provides convenient methods for parallel processing of tasks that involve * loops over indices, in which computations for different indices are * independent. *

* As a simple example, consider the following function that squares floats in * one array and stores the results in a second array. * *

 * 
 * static void sqr(float[] a, float[] b) {
 *   int n = a.length;
 *   for (int i=0; i<n; ++i)
 *     b[i] = a[i]*a[i];
 * }
 * 
 * 
* * A serial version of a similar function for 2D arrays is: * *
 * 
 * static void sqrSerial(float[][] a, float[][] b) {
 *   int n = a.length;
 *   for (int i=0; i<n; ++i)
 *     sqr(a[i],b[i]);
 * }
 * 
 * 
* * Using this class, the parallel version for 2D arrays is: * *
 * 
 * static void sqrParallel(final float[][] a, final float[][] b) {
 *   int n = a.length;
 *   Parallel.loop(n,new Parallel.LoopInt() {
 *     public void compute(int i) {
 *       sqr(a[i],b[i]);
 *     }
 *   });
 * }
 * 
 * 
 *
 * In the parallel version, the method {@code compute} defined by the interface
 * {@code LoopInt} will be called n times for different indices i in the range
 * [0,n-1]. The order of indices is both indeterminate and irrelevant because
 * the computation for each index i is independent. The arrays a and b are
 * declared final, as required before Java 8 for variables captured by an
 * anonymous implementation of {@code LoopInt}.
 *

* Note: because the method {@code loop} and interface {@code LoopInt} are * static members of this class, we can omit the class name prefix * {@code Parallel} if we first import these names with * *

 * 
 * import static utility.Parallel.*;
 * 
 * 
* * A similar method facilitates tasks that reduce a sequence of indexed values * to one or more values. For example, given the following method: * *
 * 
 * static float sum(float[] a) {
 *   int n = a.length;
 *   float s = 0.0f;
 *   for (int i=0; i<n; ++i)
 *     s += a[i];
 *   return s;
 * }
 * 
 * 
* * serial and parallel versions for 2D arrays may be written as: * *
 * 
 * static float sumSerial(float[][] a) {
 *   int n = a.length;
 *   float s = 0.0f;
 *   for (int i=0; i<n; ++i)
 *     s += sum(a[i]);
 *   return s;
 * }
 * 
 * 
* * and * *
 * 
 * static float sumParallel(final float[][] a) {
 *   int n = a.length;
 *   return Parallel.reduce(n,new Parallel.ReduceInt<Float>() {
 *     public Float compute(int i) {
 *       return sum(a[i]);
 *     }
 *     public Float combine(Float s1, Float s2) {
 *       return s1+s2;
 *     }
 *   });
 * }
 * 
 * 
* * In the parallel version, we implement the interface {@code ReduceInt} with * two methods, one to {@code compute} sums of array elements and another to * {@code combine} two such sums together. The same pattern works for other * reduce operations. For example, with similar functions we could compute * minimum and maximum values (in a single reduce) for any indexed sequence of * values. *
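 * A self-contained sketch of the single-pass minimum-and-maximum reduce
 * mentioned above. It is written directly against
 * java.util.concurrent.RecursiveTask rather than this class, so it compiles on
 * its own; the names MinMax and MinMaxTask and the serial cutoff THRESHOLD are
 * illustrative assumptions. The serial branch plays the role of
 * ReduceInt.compute, and MinMax.combine plays the role of ReduceInt.combine.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Partial result of the reduce: the minimum and maximum seen so far.
final class MinMax {
    final float min, max;

    MinMax(float min, float max) { this.min = min; this.max = max; }

    // Plays the role of ReduceInt.combine: merges two partial results.
    static MinMax combine(MinMax a, MinMax b) {
        return new MinMax(Math.min(a.min, b.min), Math.max(a.max, b.max));
    }
}

// Hypothetical fork-join task that reduces a[begin:end) to one MinMax.
final class MinMaxTask extends RecursiveTask<MinMax> {
    private static final int THRESHOLD = 1000; // assumed serial cutoff
    private final float[] a;
    private final int begin, end;

    MinMaxTask(float[] a, int begin, int end) {
        this.a = a; this.begin = begin; this.end = end;
    }

    protected MinMax compute() {
        if (end - begin <= THRESHOLD) {
            // Small range: reduce serially in a single pass.
            float lo = a[begin], hi = a[begin];
            for (int i = begin + 1; i < end; ++i) {
                lo = Math.min(lo, a[i]);
                hi = Math.max(hi, a[i]);
            }
            return new MinMax(lo, hi);
        }
        int middle = begin + (end - begin) / 2;
        MinMaxTask right = new MinMaxTask(a, middle, end);
        right.fork();                                    // right half runs asynchronously
        MinMax left = new MinMaxTask(a, begin, middle).compute();
        return MinMax.combine(left, right.join());       // combine the two partial results
    }

    // Requires a.length > 0, just as reduce requires begin < end.
    static MinMax minMax(float[] a) {
        return ForkJoinPool.commonPool().invoke(new MinMaxTask(a, 0, a.length));
    }
}
```

 * Usage: MinMaxTask.minMax(new float[] {3f, -1f, 7f}) returns a MinMax whose
 * min is -1.0 and whose max is 7.0.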

* More general loops are supported, and are equivalent to the following serial * code: * *

 * 
 * for (int i=begin; i<end; i+=step)
 *   // some computation that depends on i
 * 
 * 
* * The methods loop and reduce require that begin is less than end and that step * is positive. The requirement that begin is less than end ensures that reduce * is always well-defined. The requirement that step is positive ensures that * the loop terminates. *

* Static methods loop and reduce submit tasks to a fork-join framework that * maintains a pool of threads shared by all users of these methods. These * methods recursively split tasks so that disjoint sets of indices are * processed in parallel by different threads. *
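 * A minimal sketch of the recursive splitting described above, assuming Java 8+
 * (IntConsumer, ForkJoinPool.commonPool()). SplitLoop is a hypothetical
 * stand-in, not this class: it runs on the JDK common pool instead of this
 * class's private pool, but reuses the same middle formula and a
 * surplus-queued-task test with an assumed limit of 6. Each task either runs
 * its range serially or forks the right half and computes the left half
 * itself, so disjoint sets of indices can be processed by different threads.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.function.IntConsumer;

// Hypothetical stand-in for a loop over i = begin, begin+step, ... while i < end.
final class SplitLoop extends RecursiveAction {
    private static final int SURPLUS_LIMIT = 6; // assumed to play the role of NSQT
    private final int begin, end, step, chunk;
    private final IntConsumer body;

    SplitLoop(int begin, int end, int step, int chunk, IntConsumer body) {
        this.begin = begin; this.end = end; this.step = step;
        this.chunk = chunk; this.body = body;
    }

    // Same split point as this class's middle(): begin plus a multiple of step,
    // so both halves stay on the original index lattice.
    static int middle(int begin, int end, int step) {
        return begin + step + ((end - begin - 1) / 2) / step * step;
    }

    protected void compute() {
        // Small range, or this worker's queue already holds enough tasks: run serially.
        if (end <= begin + chunk * step || getSurplusQueuedTaskCount() > SURPLUS_LIMIT) {
            for (int i = begin; i < end; i += step) body.accept(i);
            return;
        }
        int mid = middle(begin, end, step);
        SplitLoop left = new SplitLoop(begin, mid, step, chunk, body);
        SplitLoop right = (mid < end) ? new SplitLoop(mid, end, step, chunk, body) : null;
        if (right != null) right.fork(); // the right half may be stolen by another thread
        left.compute();                  // the left half runs on this thread
        if (right != null) right.join();
    }

    static void loop(int begin, int end, int step, int chunk, IntConsumer body) {
        ForkJoinPool.commonPool().invoke(new SplitLoop(begin, end, step, chunk, body));
    }

    // Counts visits to each index in [0, end); assumes 0 <= begin < end.
    static int[] visitCounts(int begin, int end, int step, int chunk) {
        AtomicIntegerArray hits = new AtomicIntegerArray(end);
        loop(begin, end, step, chunk, hits::incrementAndGet);
        int[] counts = new int[end];
        for (int i = 0; i < end; ++i) counts[i] = hits.get(i);
        return counts;
    }
}
```

 * For example, SplitLoop.visitCounts(3, 1000, 7, 1) should report exactly one
 * visit for each of 3, 10, 17, ... and none for any other index, whichever
 * threads happened to run the pieces.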

 * In addition to the three loop parameters begin, end, and step, a fourth
 * parameter chunk may be specified. This chunk parameter is a threshold for
 * splitting tasks so that they can be performed in parallel. If a range of
 * indices to be processed is no larger than the chunk size, or if too many
 * tasks have already been queued for processing, then the indices are
 * processed serially. Otherwise, the range is split into two nearly equal
 * parts for processing by new tasks. Because a range is split only when it
 * holds more than chunk indices, each range processed serially holds at least
 * floor((chunk+1)/2) indices (unless the entire loop holds fewer), and it may
 * hold more than chunk indices when the test for queued tasks stops splitting
 * early. The default chunk size is one.
 *
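 * The effect of the chunk threshold can be traced without threads. This sketch
 * (hypothetical class ChunkTrace) reuses the middle formula from this class's
 * implementation, deliberately ignores the queued-task test, and lists how many
 * indices each serially processed range would hold, from left to right.

```java
import java.util.ArrayList;
import java.util.List;

// Serial trace of the split decisions (the queued-task test is ignored).
final class ChunkTrace {
    // Same split point as this class's middle().
    static int middle(int begin, int end, int step) {
        return begin + step + ((end - begin - 1) / 2) / step * step;
    }

    // Number of indices in each range that would be processed serially, left to right.
    static List<Integer> leafSizes(int begin, int end, int step, int chunk) {
        List<Integer> sizes = new ArrayList<>();
        collect(begin, end, step, chunk, sizes);
        return sizes;
    }

    private static void collect(int begin, int end, int step, int chunk, List<Integer> out) {
        if (end <= begin + chunk * step) {
            out.add((end - begin + step - 1) / step); // count of begin, begin+step, ... below end
            return;
        }
        int mid = middle(begin, end, step);
        collect(begin, mid, step, chunk, out);
        if (mid < end) collect(mid, end, step, chunk, out);
    }
}
```

 * For example, ChunkTrace.leafSizes(0, 10, 1, 4) returns [3, 2, 3, 2]: the ten
 * indices are halved into 5 and 5, and each 5 (more than chunk = 4) is split
 * again into 3 and 2.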

* The default chunk size is often sufficient, because the test for an excess * number of queued tasks prevents tasks from being split needlessly. This test * is especially useful when parallel loops are nested, as when looping over * elements of multi-dimensional arrays. *

 * For example, an implementation of the method {@code sqrParallel} for 3D
 * arrays could simply call the 2D version listed above. Tasks will naturally
 * tend to be split for outer loops, but not inner loops, thereby reducing
 * overhead (the time spent splitting and queueing tasks).
 *
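 * A sketch of the same nesting written with the standard java.util.stream API
 * (Java 8+), which also schedules work on a shared fork-join pool (the common
 * pool). This is a swapped-in technique, not this class: parallel streams use
 * their own splitting heuristics rather than the queued-task test, so the
 * sketch illustrates only the structure of a 3D loop that simply calls the 2D
 * version. NestedSqr is a hypothetical name.

```java
import java.util.stream.IntStream;

// Hypothetical helpers mirroring sqr / sqrParallel from the examples above,
// written with parallel streams instead of the Parallel class.
final class NestedSqr {
    static void sqr(float[] a, float[] b) {
        for (int i = 0; i < a.length; ++i) b[i] = a[i] * a[i];
    }

    // 2D: one independent task per row.
    static void sqrParallel(float[][] a, float[][] b) {
        IntStream.range(0, a.length).parallel().forEach(i -> sqr(a[i], b[i]));
    }

    // 3D: simply calls the 2D version for each slab; the inner loop is itself parallel.
    static void sqrParallel(float[][][] a, float[][][] b) {
        IntStream.range(0, a.length).parallel().forEach(i -> sqrParallel(a[i], b[i]));
    }
}
```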

 * Reference: A Java Fork/Join Framework, by Doug Lea, describes the framework
 * used to implement this class. This framework is part of the JDK as of Java 7
 * (package java.util.concurrent).
 *
 * @author Dave Hale, Colorado School of Mines
 * @version 2010.11.23
 */
public class Parallel {

    /** A loop body that computes something for an int index. */
    public interface LoopInt {
        /**
         * Computes for the specified loop index.
         *
         * @param i loop index.
         */
        public void compute(int i);
    }

    /** A loop body that computes and returns a value for an int index. */
    public interface ReduceInt<V> {
        /**
         * Returns a value computed for the specified loop index.
         *
         * @param i loop index.
         * @return the computed value.
         */
        public V compute(int i);

        /**
         * Returns the combination of two specified values.
         *
         * @param v1 a value.
         * @param v2 a value.
         * @return the combined value.
         */
        public V combine(V v1, V v2);
    }

    /**
     * A wrapper for objects that are not thread-safe. Such objects have methods
     * that cannot safely be executed concurrently in multiple threads. To use
     * an unsafe object within a parallel computation, first construct an
     * instance of this wrapper. Then, within the compute method, get the unsafe
     * object; if null, construct and set a new unsafe object in this wrapper,
     * before using the unsafe object to perform the computation. This pattern
     * ensures that each thread computes using a distinct unsafe object. For
     * example,
     *

	 * 
	 * final Parallel.Unsafe<Worker> nts = new Parallel.Unsafe<Worker>();
	 * Parallel.loop(count,new Parallel.LoopInt() {
	 *   public void compute(int i) {
	 *     Worker w = nts.get(); // get worker for the current thread
	 *     if (w==null) nts.set(w=new Worker()); // if null, make one
	 *     w.work(); // the method work need not be thread-safe
	 *   }
	 * });
	 * 
	 * 
* * This wrapper is most useful when (1) the cost of constructing an unsafe * object is high, relative to the cost of each call to compute, and (2) the * number of threads calling compute is significantly lower than the total * number of such calls. Otherwise, if either of these conditions is false, * then simply construct a new unsafe object within the compute method. *
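 * A self-contained sketch of the per-thread pattern above, assuming Java 8+.
 * It swaps in a parallel stream for Parallel.loop and uses a ConcurrentHashMap
 * keyed by Thread, as Unsafe does internally; Worker and PerThreadDemo are
 * hypothetical names. Each Worker is touched by one thread only, and the
 * partial sums are combined after the loop has ended, which is when getAll()
 * is meant to be used.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.IntStream;

// Hypothetical worker that is NOT thread-safe: it updates a plain field.
final class Worker {
    long sum;

    void work(int i) { sum += i; }
}

final class PerThreadDemo {
    // Sums 0..count-1 with one Worker per thread, mirroring the get/set pattern.
    static long sumWithPerThreadWorkers(int count) {
        Map<Thread, Worker> workers = new ConcurrentHashMap<>();
        IntStream.range(0, count).parallel().forEach(i -> {
            Worker w = workers.get(Thread.currentThread()); // worker for the current thread
            if (w == null) {
                w = new Worker();                             // first call on this thread
                workers.put(Thread.currentThread(), w);
            }
            w.work(i); // safe: no other thread uses this Worker
        });
        // After the loop has ended, combine the per-thread partial results.
        long total = 0;
        for (Worker w : workers.values()) total += w.sum;
        return total;
    }
}
```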

     * This wrapper works much like the Java standard class ThreadLocal, except
     * that an object within this wrapper can be garbage-collected before its
     * thread dies. This difference is important because fork-join worker
     * threads are pooled and will typically die only when a program ends.
     */
    public static class Unsafe<T> {
        /**
         * Constructs a wrapper for objects that are not thread-safe.
         */
        public Unsafe() {
            int initialCapacity = 16; // the default initial capacity
            float loadFactor = 0.5f; // huge numbers of threads are unlikely
            int concurrencyLevel = 2 * _pool.getParallelism();
            _map = new ConcurrentHashMap<Thread, T>(initialCapacity, loadFactor, concurrencyLevel);
        }

        /**
         * Gets the object in this wrapper for the current thread.
         *
         * @return the object; null, if not yet set for the current thread.
         */
        public T get() {
            return _map.get(Thread.currentThread());
        }

        /**
         * Sets the object in this wrapper for the current thread.
         *
         * @param object the object.
         */
        public void set(T object) {
            _map.put(Thread.currentThread(), object);
        }

        /**
         * Returns a collection of all unsafe objects in this wrapper. This
         * method is useful only after parallel loops have ended.
         *
         * @return the collection of unsafe objects.
         */
        public Collection<T> getAll() {
            return _map.values();
        }

        private final ConcurrentHashMap<Thread, T> _map;
    }

    /**
     * Performs a loop for (int i=0; i<end; ++i).
     *
     * @param end the end index (not included) for the loop.
     * @param body the loop body.
     */
    public static void loop(int end, LoopInt body) {
        loop(0, end, 1, 1, body);
    }

    /**
     * Performs a loop for (int i=begin; i<end; ++i).
     *
     * @param begin the begin index for the loop; must be less than end.
     * @param end the end index (not included) for the loop.
     * @param body the loop body.
     */
    public static void loop(int begin, int end, LoopInt body) {
        loop(begin, end, 1, 1, body);
    }

    /**
     * Performs a loop for (int i=begin; i<end; i+=step).
     *
     * @param begin the begin index for the loop; must be less than end.
     * @param end the end index (not included) for the loop.
     * @param step the index increment; must be positive.
     * @param body the loop body.
     */
    public static void loop(int begin, int end, int step, LoopInt body) {
        loop(begin, end, step, 1, body);
    }

    /**
     * Performs a loop for (int i=begin; i<end; i+=step).
     *
     * @param begin the begin index for the loop; must be less than end.
     * @param end the end index (not included) for the loop.
     * @param step the index increment; must be positive.
     * @param chunk the chunk size; must be positive.
     * @param body the loop body.
     */
    public static void loop(int begin, int end, int step, int chunk, LoopInt body) {
        checkArgs(begin, end, step, chunk);
        if (_serial || end <= begin + chunk * step) {
            for (int i = begin; i < end; i += step) {
                body.compute(i);
            }
        } else {
            LoopIntAction task = new LoopIntAction(begin, end, step, chunk, body);
            if (LoopIntAction.inForkJoinPool()) {
                task.invoke();
            } else {
                _pool.invoke(task);
            }
        }
    }

    /**
     * Performs a reduce for (int i=0; i<end; ++i).
     *
     * @param end the end index (not included) for the loop.
     * @param body the loop body.
     * @return the computed value.
     */
    public static <V> V reduce(int end, ReduceInt<V> body) {
        return reduce(0, end, 1, 1, body);
    }

    /**
     * Performs a reduce for (int i=begin; i<end; ++i).
     *
     * @param begin the begin index for the loop; must be less than end.
     * @param end the end index (not included) for the loop.
     * @param body the loop body.
     * @return the computed value.
     */
    public static <V> V reduce(int begin, int end, ReduceInt<V> body) {
        return reduce(begin, end, 1, 1, body);
    }

    /**
     * Performs a reduce for (int i=begin; i<end; i+=step).
     *
     * @param begin the begin index for the loop; must be less than end.
     * @param end the end index (not included) for the loop.
     * @param step the index increment; must be positive.
     * @param body the loop body.
     * @return the computed value.
     */
    public static <V> V reduce(int begin, int end, int step, ReduceInt<V> body) {
        return reduce(begin, end, step, 1, body);
    }

    /**
     * Performs a reduce for (int i=begin; i<end; i+=step).
     *
     * @param begin the begin index for the loop; must be less than end.
     * @param end the end index (not included) for the loop.
     * @param step the index increment; must be positive.
     * @param chunk the chunk size; must be positive.
     * @param body the loop body.
     * @return the computed value.
     */
    public static <V> V reduce(int begin, int end, int step, int chunk, ReduceInt<V> body) {
        checkArgs(begin, end, step, chunk);
        if (_serial || end <= begin + chunk * step) {
            V v = body.compute(begin);
            for (int i = begin + step; i < end; i += step) {
                V vi = body.compute(i);
                v = body.combine(v, vi);
            }
            return v;
        } else {
            ReduceIntTask<V> task = new ReduceIntTask<V>(begin, end, step, chunk, body);
            if (ReduceIntTask.inForkJoinPool()) {
                return task.invoke();
            } else {
                return _pool.invoke(task);
            }
        }
    }

    /**
     * Enables or disables parallel processing by all methods of this class. By
     * default, parallel processing is enabled. If disabled, all tasks will be
     * executed on the current thread.
     *

     * Setting this flag to false disables parallel processing for all
     * users of this class. This method should therefore be used for
     * testing and benchmarking only.
     *
     * @param parallel true, for parallel processing; false, otherwise.
     */
    public static void setParallel(boolean parallel) {
        _serial = !parallel;
    }

    // /////////////////////////////////////////////////////////////////////////
    // private

    // Implementation notes:
    // Each fork-join task below has a range of indices to be processed.
    // If the range is less than or equal to the chunk size, or if the
    // queue for the current thread holds too many tasks already, then
    // simply process the range on the current thread. Otherwise, split
    // the range into two parts that are approximately equal, ensuring
    // that the left part is at least as large as the right part. If the
    // right part is not empty, fork a new task. Then compute the left
    // part in the current thread, and, if necessary, join the right part.

    // Threshold for number of surplus queued tasks. Used below to
    // determine whether or not to split a task into two subtasks.
    private static final int NSQT = 6;

    // The pool shared by all fork-join tasks created through this class.
    private static ForkJoinPool _pool = new ForkJoinPool();

    // Serial flag; true for no parallel processing.
    private static boolean _serial = false;

    /**
     * Checks loop arguments.
     */
    private static void checkArgs(int begin, int end, int step, int chunk) {
        argument(begin < end, "begin<end");
        argument(step > 0, "step>0");
        argument(chunk > 0, "chunk>0");
    }

    public static void argument(boolean condition, String message) {
        if (!condition)
            throw new IllegalArgumentException("required condition: " + message);
    }

    /**
     * Splits range [begin:end) into [begin:middle) and [middle:end). The
     * returned middle index equals begin plus an integer multiple of step.
     */
    private static int middle(int begin, int end, int step) {
        return begin + step + ((end - begin - 1) / 2) / step * step;
    }

    /**
     * Fork-join task for parallel loop.
     */
    private static class LoopIntAction extends RecursiveAction {
        LoopIntAction(int begin, int end, int step, int chunk, LoopInt body) {
            assert begin < end : "begin < end";
            _begin = begin;
            _end = end;
            _step = step;
            _chunk = chunk;
            _body = body;
        }

        @Override
        protected void compute() {
            if (_end <= _begin + _chunk * _step || getSurplusQueuedTaskCount() > NSQT) {
                for (int i = _begin; i < _end; i += _step) {
                    _body.compute(i);
                }
            } else {
                int middle = middle(_begin, _end, _step);
                LoopIntAction l = new LoopIntAction(_begin, middle, _step, _chunk, _body);
                LoopIntAction r = (middle < _end) ? new LoopIntAction(middle, _end, _step, _chunk, _body)
                        : null;
                if (r != null)
                    r.fork();
                l.compute();
                if (r != null)
                    r.join();
            }
        }

        private final int _begin, _end, _step, _chunk;
        private final LoopInt _body;
    }

    /**
     * Fork-join task for parallel reduce.
     */
    private static class ReduceIntTask<V> extends RecursiveTask<V> {
        ReduceIntTask(int begin, int end, int step, int chunk, ReduceInt<V> body) {
            assert begin < end : "begin < end";
            _begin = begin;
            _end = end;
            _step = step;
            _chunk = chunk;
            _body = body;
        }

        @Override
        protected V compute() {
            if (_end <= _begin + _chunk * _step || getSurplusQueuedTaskCount() > NSQT) {
                V v = _body.compute(_begin);
                for (int i = _begin + _step; i < _end; i += _step) {
                    V vi = _body.compute(i);
                    v = _body.combine(v, vi);
                }
                return v;
            } else {
                int middle = middle(_begin, _end, _step);
                ReduceIntTask<V> l = new ReduceIntTask<V>(_begin, middle, _step, _chunk, _body);
                ReduceIntTask<V> r = (middle < _end) ?
new ReduceIntTask( middle, _end, _step, _chunk, _body) : null; if (r != null) r.fork(); V v = l.compute(); if (r != null) v = _body.combine(v, r.join()); return v; } } private final int _begin, _end, _step, _chunk; private final ReduceInt _body; } } ================================================ FILE: test/corpus.LABEL ================================================ apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple apple google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google google microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft 
microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft microsoft twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter twitter ================================================ FILE: test/corpus.txt ================================================ iphone crack iphone adding support iphone announced youtube video guy siri pretty love rim made easy switch iphone yeah realized ios current blackberry user bit disappointed move android iphone things siri sooo glad gave siri sense humor great personal event tonight store companies experience customer service apply job hope call lol lmao siri find hide body registered developer appreciated wow 
great deals ipad gen offers great deals gen ipads learning trip hong kong gotta hand iphones apps dark side hey send free iphone publicly burn blackberry find mac air macbook keyboard lunch break today warranty ipads replace siri amazing amazing ios feature reply featured education apps website today sweet reply useless days iphone yesterday awesome amount info question brother iphone people iphone phone happy ceo points ios bus iphone umber appstore itunes mobile devices talking desktop application bring ipad ipad set red red ipad sells million iphone weekend steve jobs lives iphone apologize downloads ios users lmfao argument siri incredible million iphone screenshot days iphone iphone fixed ios battery drain problem replacement iphone working brand macbook professional macbook years miss time siri dad mom brother girlfriend store amazing call waiting music sweet replaced bad sells million iphones debut weekend smartphone loving technology iphone mac air icloud technology loving ios update mention store great customer service store time iphone forward man longer paying texts girlfriend iphone great icloud set works cloud mommy totally email company great service store loving ios upgrade iphone ios ipad making switch android iphone iphone smartphone store incredible people offering water macbook professional wow macbook sick play man loving camera iphone facebook yeah ios changed life reader worldwide web love service case hand case years jobs iphone iphone iphone blackberry years lost service moving iphone sells million iphones days weekend iphone macbook professional year time selling android post card putting kind glad hear alive god youtube bad ass system loving days iphone nice gave ios email lock screen opening unlocking word wow iphone weekend sales top million love ios easter eggs pull middle top bottom pulls awesome feature ios love ios easter eggs pull middle top bottom pulls awesome feature run beautiful morning man love ios iphone simply made happy 
text lol text day great customer service received today phone phone loving ipod update upgraded iphone siri worth upgrade forward siri great world missed loving iphone ios love iphone great genius cards application card arrived local post office today iphone siri meet siri iphone click link work feel worst ios upgrade good luck blackberry loving ios awesome iphone addicted club guy playing facetime watching game bar blackberry boo powered technology work ipod time iphone good job guys siri year lead lost ios sweet notifications phone search covers mail wifi sync icloud backup integrated great james story today times retail success world due ios guys impressive service genius bar metro center power replaced free screen replacement free nice guy store replaced phone showed crack screen iphone battery longer day happened edge iphone nice job minutes write blackberry showing eye phone impressed iphone space amazing products people things making ipad feel ios nexus good feel bit guess android users android nice game helps search nice game helps search facebook build website website free android ics pretty good worth android ice cream sandwich nexus android nexus exciting day ice cream sandwich day android wow nexus beautiful totally gonna market share smart phone market integrated data usage manager brilliant design watching lol ice cream sandwich android works htc desire ice cream sandwich sounds android ice cream sandwich amazing imo android missing forget phone nice feature android nexus finally unveiled android ice cream sandwich good finally searches logged users rim strategy released hours release ics man love galaxy nexus samsung android doubt share winning war dear galaxy nexus send email technology telegraph reports biggest threat facebook power users samsung made bad android king facebook power users telegraph socialmedia impressed android update good font design video wallet wow tweet remember spell straight android samsung nexus efficient fun releases 
infinite digital bookcase pass social seo facebook ice cream sandwich stop carriers bullying smartphone users android agree freaking awesome icecream great helps samsung galaxy nexus iphone ice cream sandwich delicious iphone launches android aka loving samsung push mobile experience forward finally power volume screenshot ics nexus press conference slick high school appreciated scream scream scream android job major game mobile space thinking ahead venturebeat virtual bookcase sharing android phone keeping iphone android ice cream sandwich feature closer roboto type face read work samsung android ics impressive add profile webgl project add addthis work company work invention wait ice cream sandwich android stop nexus phone android device updated galaxy nexus android introducing ice cream sandwich delicious version android ics excited android features android ics wait nexus play check video introducing galaxy nexus simple beautiful smart youtube android nexus cream ice cream phone job great small businesses platform features thoughts loves presentations tool docs adding video brilliant webgl bookcase searches things android ice cream introducing galaxy nexus simple beautiful smart nexus prime android interesting bookcase venturebeat releases infinite digital bookcase good finally focus user experience android ics awesome phone android motorola iphone ice cream sandwich android nexus line smart move android beam alright made team team android android reply font good start ics ice cream sandwich face unlock works ready ice cream sandwich ics nexus android android ice cream sandwich android taste ice cream sandwich bite samsung event live blog gadget haven android android ice cream sandwich make smartphone operating systems photo sharing people application ice cream sandwich imo ics android nexus phone makes iphone cheap store android sweet ice cream sandwich android ice cream sandwich officially ics raise hand android powered phone samsung siri android device 
replace iphone nexus page live nexus android excited android beam face unlock android ics linkedin tools company page contact samsung ice cream sandwich samsung introducing galaxy nexus simple beautiful smart android ics samsung glad design android shows waiting thoughts android ics excited play features android register galaxy nexus android wow webgl infinite bookcase ics awesome wait face unlock android gotta pretty android chrome android november direct purchase samsung nexus wanna awesome event time change android samsung ios user ics awesome great job yeah great job ics literally mind blown samsung motorola verizon perfect opens door spanish entrepreneurs project intel ibm windows phone mango update process ahead schedule mango back smartphone rich word works computer free gen stores watch codename data explorer ctp coming lunch today vslive watch codename data explorer ctp coming month details search improvements windows start screen mango shows taste smartphone success mango awesome moving dev finally local stores offer free windows phone devices stores offer free windows phone devices neowin store spend hard vslive free west check hey parents free tools kids online live family cloud offers students free access improve tcn awesome bit details windows search improvements yeah taking metro yeah good android love kids tech explains improvements windows start screen search tech search idea search great bing king search search powerpoint users power create service bye solutions future information innovators nov info curate personal history project greenwich month beam research project great sql server session works days ballmer thinks computer scientist android tech agree great time win server works fine vmware wow tech turns body touchscreen psfk love love feeling building vslive bringing conference research shows awesome step closer bit kinect research shows science science fact cool sound research shows science science fact zune music canada music news kinect 
makes learning playful education mango check change world good world wait watching windows pretty impressive finally mac interesting battle store xbox share god blog post cool tool mouse tools forget siri beating speech commands mango siri tests proves appsense enterprise capability users personalization database enterprise software good points sap dynamics good dev secure anti impressed creating images mac blown marketing yahoo sale years back bought glad deal year omnitouch impressive technology good bing paying ipads windows tablets study home day great time mango shows taste smartphone success picture services cloud love windows net dev nice talk community omg sharepoint working innovation sad sad office love genius love gates foundation good skype family amazing things absolutely loving mouse fan cool video turn surface touchscreen wow android ics lots talk mango launch people public speaking updated computer windows ics android kill mango nokia people names mail week outlook mac sucks hate xbox accounts hack reports update net windows media center fail eclipsed word upgrade doc doc word won open doc suck u.s. 
antitrust leaving business played dumb lync crash issue mac fixed broke played engages racketeering calls respect nokia chief executive mole frozen xbox live xbl accounts online games report hacked gave windows dev preview good waiting beta windows powerpoint fix powerpoint presentations eclipsed guardian kind search great time family advertising windows forget past antitrust issues paying make racketeering day talking talk tomorrow waiting reader compares albatross neck agree join lot word freeze minutes lol perfect simple hate windows phones months months lose reader compares albatross neck agree join discussion make sleep plan feel world put facebook blackberry helps miss boo everytime leave back back telling lol application ass theme sleep sleep starting sending hashtags emails taking lives shit lol hell today introduced social media love facebook yeah shows glad pretty facebook gotta love shit round world speed bed gonna minute bed dear fucking missed today internet tweet keeping busy school good thing people left social social media guess addicted university exam questions good thing people left social side apples facebook content bed favorite application facebook facebook change makes excited privacy impressive numbers smm socialmedia fuck facebook bullshit bitch cool love fuck facebook follow haven shit man haven fun find song end television show watched literally back facebook text email technology good isn pretty damn amazing hope year fast dear missed promise touch bored sad mad happy true friend facebook sucks amp shit funny haven shit day voice people real life lol yeah time bug science hashtags facebook feeling real world biggest facebook messed make add reliable freaking kidding wth tomorrow blue ass bird continued dead emails telling sucks follow people reporting retweets working technical problem back lol retweets broken haven tuesday tomorrow blue ass bird ass ain showing current mentions tweets gonna problems fixed asap retweets man boring 
application show touch tweet trouble application updating application messed everytime text message show fucking retweets bitch sooo trash showing retweets shit mom argument pretty addicted care appreciated start working computer retweets section account working hours problem good send bloody tweets feel make account fucking late damn dear fix shit retweets mentions dead fuck point giving tweets tweeted past days lol messed followers numbers timeline mentions shit garbage hell television man stupid fucking give damn mentions ugh fucking facebook television wanna study show retweets ill back facebook reply opinions forget day time haven blogs tumblr talk step game reminder fail join follow ways competition people facebook day life drop follow show love telling reply sleep time emotions call night work break time yeah tumblr love age year days hours minutes seconds find wanna aye shit living shout favorite people happy girls follow back sleep good people trip ================================================ FILE: test/corpus_test.txt ================================================ making ipad feel ios nexus good feel bit guess android users android nice game helps search nice game helps search facebook build website website free android ics pretty good worth android ice cream sandwich nexus android nexus exciting day ice cream sandwich day android wow nexus beautiful totally gonna market share smart phone market integrated data usage manager brilliant design watching lol ice cream sandwich android works htc desire ice cream sandwich sounds android ice cream sandwich amazing imo android missing forget phone nice feature android nexus finally unveiled android ice cream sandwich good finally searches logged users rim strategy released hours release ics man love galaxy nexus samsung android doubt share winning war dear galaxy nexus send email technology telegraph reports biggest threat facebook power users samsung made bad android king facebook power users telegraph 
socialmedia impressed android update good font design video wallet wow tweet remember spell straight android samsung nexus efficient fun releases infinite digital bookcase pass social seo facebook ice cream sandwich stop carriers bullying smartphone users android agree freaking awesome icecream great helps samsung galaxy nexus iphone ice cream sandwich delicious iphone launches android aka loving samsung push mobile experience forward finally power volume screenshot ics nexus press conference slick high school appreciated scream scream scream android job major game mobile space thinking ahead venturebeat virtual bookcase sharing android phone keeping iphone android ice cream sandwich feature closer roboto type face read work samsung android ics impressive add profile webgl project add addthis work company work invention wait ice cream sandwich android stop nexus phone android device updated galaxy nexus android introducing ice cream sandwich delicious version android ics excited android features android ics wait nexus play check video introducing galaxy nexus simple beautiful smart youtube android nexus cream ice cream phone job great small businesses platform features thoughts loves presentations tool docs adding video brilliant webgl bookcase searches things android ice cream introducing galaxy nexus simple beautiful smart nexus prime android interesting bookcase venturebeat releases infinite digital bookcase good finally focus user experience android ics awesome phone android motorola iphone ice cream sandwich android nexus line smart move android beam alright made team team android android reply font good start ics ice cream sandwich face unlock works ready ice cream sandwich ics nexus android android ice cream sandwich android taste ice cream sandwich bite samsung event live blog gadget haven android android ice cream sandwich make smartphone operating systems photo sharing people application ice cream sandwich imo ics android nexus phone makes iphone cheap 
store android sweet ice cream sandwich android ice cream sandwich officially ics raise hand android powered phone samsung siri android device replace iphone nexus page live nexus android excited android beam face unlock android ics linkedin tools company page contact samsung ice cream sandwich samsung introducing galaxy nexus simple beautiful smart android ics samsung glad design android shows waiting thoughts android ics excited play features android register galaxy nexus android wow webgl infinite bookcase ics awesome wait face unlock android gotta pretty android chrome android november direct purchase samsung nexus wanna awesome event time change android samsung ios user ics awesome great job yeah great job ics literally mind blown samsung motorola verizon perfect opens door spanish entrepreneurs project intel ibm windows phone mango update process ahead schedule mango back smartphone rich word works computer free gen stores watch codename data explorer ctp coming lunch today vslive watch codename data explorer ctp coming month details search improvements windows start screen mango shows taste smartphone success mango awesome moving dev finally local stores offer free windows phone devices stores offer free windows phone devices neowin store spend hard vslive free west check hey parents free tools kids online live family cloud offers students free access improve tcn awesome bit details windows search improvements yeah taking metro yeah good android love kids tech explains improvements windows start screen search tech search idea search great bing king search search powerpoint users power create service bye solutions future information innovators nov info curate personal history project greenwich month beam research project great sql server session works days ballmer thinks computer scientist android tech agree great time win server works fine vmware wow tech turns body touchscreen psfk love love feeling building vslive bringing conference research shows awesome 
step closer bit kinect research shows science science fact cool sound research shows science science fact zune music canada music news kinect makes learning playful education mango check change world good world wait watching windows pretty impressive finally mac interesting battle store xbox share god blog post cool tool mouse tools forget siri beating speech commands mango siri tests proves appsense enterprise capability users personalization database enterprise software good points sap dynamics good dev secure anti impressed creating images mac blown marketing yahoo sale years back bought glad deal year omnitouch impressive technology good bing paying ipads windows tablets study home day great time mango shows taste smartphone success picture services cloud love windows net dev nice talk community omg sharepoint working innovation sad sad office love genius love gates foundation good skype family amazing things absolutely loving mouse fan cool video turn surface touchscreen wow android ics lots talk mango launch people public speaking updated computer windows ics android kill mango nokia people names mail week outlook mac sucks hate xbox accounts hack reports update net windows media center fail eclipsed word upgrade doc doc word won open doc suck u.s. 
antitrust leaving business played dumb lync crash issue mac fixed broke played engages racketeering calls respect nokia chief executive mole frozen xbox live xbl accounts online games report hacked gave windows dev preview good waiting beta windows powerpoint fix powerpoint presentations eclipsed guardian kind search great time family advertising windows forget past antitrust issues paying make racketeering day talking talk tomorrow waiting reader compares albatross neck agree join lot word freeze minutes lol perfect simple hate windows phones months months lose ================================================ FILE: test/wordVectors.txt ================================================ [File too large to display: 12.5 MB]