Repository: hexiangnan/sigir16-eals
Branch: master
Commit: 993d19f37f61
Files: 31
Total size: 18.1 MB
Directory structure:
gitextract_21_aynok/
├── .classpath
├── .project
├── README.md
├── data/
│ ├── README.md
│ └── yelp.rating
├── lib/
│ ├── happy.coding.utils.jar
│ └── json-simple.jar
└── src/
├── algorithms/
│ ├── ItemKNN.java
│ ├── ItemPopularity.java
│ ├── MF_ALS.java
│ ├── MF_CD.java
│ ├── MF_fastALS.java
│ ├── MFbpr.java
│ └── TopKRecommender.java
├── data_structure/
│ ├── DataMap.java
│ ├── DenseMatrix.java
│ ├── DenseVector.java
│ ├── Pair.java
│ ├── Rating.java
│ ├── SparseMatrix.java
│ └── SparseVector.java
├── main/
│ ├── main.java
│ ├── main_MF.java
│ ├── main_bpr.java
│ └── main_online.java
└── utils/
├── CommonUtils.java
├── DatasetUtil.java
├── Printer.java
├── SortMapExample.java
├── StopwordsFilter.java
└── TopKPriorityQueue.java
================================================
FILE CONTENTS
================================================
================================================
FILE: .classpath
================================================
================================================
FILE: .project
================================================
sigir16-eals
org.eclipse.jdt.core.javabuilder
org.eclipse.jdt.core.javanature
================================================
FILE: README.md
================================================
# sigir16-eals
Experimental code for the SIGIR'16 paper "Fast Matrix Factorization for Online Recommendation with Implicit Feedback".
================================================
FILE: data/README.md
================================================
The Amazon dataset is too large to include here. If you need it, please email xiangnanhe@gmail.com to request it.
================================================
FILE: data/yelp.rating
================================================
[File too large to display: 17.9 MB]
================================================
FILE: src/algorithms/ItemKNN.java
================================================
package algorithms;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import utils.CommonUtils;
import data_structure.Rating;
import data_structure.SparseMatrix;
import data_structure.SparseVector;
/**
 * Implement ItemKNN method for topK recommendation, as described in:
 * Collaborative filtering for implicit feedback datasets.
 * By Yifan Hu , Yehuda Koren , Chris Volinsky.
 * In IEEE ICDM'2008.
 *
 * <p>Item-item cosine similarities are computed between the columns of trainMatrix. The score of
 * (u, i) is the inner product of user u's training row with row i of the similarity matrix.
 *
 * @author xiangnanhe
 *
 */
public class ItemKNN extends TopKRecommender {
  /** Similarity matrix of item-item; row i holds the retained neighbors of item i. */
  public SparseMatrix similarity;
  /** K neighbors to consider for each item; K <= 0 keeps every neighbor with a non-zero score. */
  private int K = 0;
  /** Cache the L2 length of each item's column in trainMatrix. */
  double[] lengths;

  public ItemKNN(SparseMatrix trainMatrix, ArrayList testRatings,
      int topK, int threadNum, int K) {
    super(trainMatrix, testRatings, topK, threadNum);
    this.K = K;
    this.similarity = new SparseMatrix(itemCount, itemCount);
  }

  /**
   * Builds the item-item similarity matrix. Items are split into contiguous ranges,
   * one per worker thread; at least one worker is always used.
   */
  public void buildModel() {
    // The length cache
    lengths = new double[itemCount];
    for (int i = 0; i < itemCount; i++) {
      lengths[i] = Math.sqrt(trainMatrix.getColRef(i).squareSum());
    }
    // Run model multi-threads split by items. A threadNum below 1 would start no threads
    // and silently leave the similarity matrix empty, so fall back to one thread.
    int numThreads = Math.max(1, threadNum);
    int chunk = itemCount / numThreads;
    ItemKNNThread[] threads = new ItemKNNThread[numThreads];
    for (int t = 0; t < numThreads; t++) {
      int startItem = chunk * t;
      // The last thread also takes the remainder of the integer division.
      int endItem = (t == numThreads - 1) ? itemCount : chunk * (t + 1);
      threads[t] = new ItemKNNThread(this, startItem, endItem);
      threads[t].start();
    }
    // Wait until all threads are finished.
    boolean interrupted = false;
    for (int t = 0; t < threads.length; t++) {
      try {
        threads[t].join();
      } catch (InterruptedException e) {
        System.err.println("InterruptException was caught: " + e.getMessage());
        interrupted = true;
      }
    }
    // Restore the interrupt status caught above so callers can still observe it.
    if (interrupted) {
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Fills rows [startItem, endItem) of the similarity matrix with cosine similarities.
   *
   * NOTE(review): worker threads call similarity.setValue concurrently for different rows.
   * If SparseMatrix also maintains column storage internally (it exposes getColRef), those
   * writes may touch shared state. Confirm SparseMatrix.setValue is safe for this usage.
   *
   * @param startItem first item to process (inclusive)
   * @param endItem   last item to process (exclusive)
   */
  protected void buildModel_items(int startItem, int endItem) {
    // Build the similarity matrix for selected items.
    for (int i = startItem; i < endItem; i++) {
      HashMap<Integer, Double> map_item_score = new HashMap<Integer, Double>();
      for (int j = 0; j < itemCount; j++) {
        // Skip only the item itself; every other item must be scanned. Folding this test
        // into the loop condition would stop the scan at j == i.
        if (j == i) {
          continue;
        }
        // Cosine similarity. A non-zero inner product implies both lengths are non-zero.
        double score = trainMatrix.getColRef(i).innerProduct(trainMatrix.getColRef(j));
        if (score != 0) {
          score /= (lengths[i] * lengths[j]);
          map_item_score.put(j, score);
        }
      }
      if (K <= 0) { // All neighbors
        for (int j : map_item_score.keySet()) {
          similarity.setValue(i, j, map_item_score.get(j));
        }
      } else { // Only K nearest neighbors
        for (int j : CommonUtils.TopKeysByValue(map_item_score, K, null)) {
          similarity.setValue(i, j, map_item_score.get(j));
        }
      } // end if
    } // end for
  }

  /** Scores item i for user u as u's training row dotted with row i of the similarity matrix. */
  public double predict(int u, int i) {
    return trainMatrix.getRowRef(u).innerProduct(similarity.getRowRef(i));
  }

  @Override
  public void updateModel(int u, int i) {
    // TODO Implement SIGMOD15 paper
  }
}
/**
 * Worker thread that computes the similarity rows of one contiguous item range by delegating
 * to {@code ItemKNN.buildModel_items}.
 */
class ItemKNNThread extends Thread {
  /** Model whose similarity rows this worker fills. */
  ItemKNN model;
  /** First item handled by this worker (inclusive). */
  int startItem;
  /** End of the item range handled by this worker (exclusive). */
  int endItem;

  public ItemKNNThread(ItemKNN model, int startItem, int endItem) {
    this.startItem = startItem;
    this.endItem = endItem;
    this.model = model;
  }

  @Override
  public void run() {
    final int from = this.startItem;
    final int to = this.endItem;
    model.buildModel_items(from, to);
  }
}
================================================
FILE: src/algorithms/ItemPopularity.java
================================================
package algorithms;
import java.util.ArrayList;
import java.util.HashMap;
import data_structure.Rating;
import data_structure.SparseMatrix;
/**
 * Non-personalized baseline: every user receives the same score for an item, equal to that
 * item's popularity in the training data.
 */
public class ItemPopularity extends TopKRecommender {
  /** Popularity score per item, indexed by item id. */
  double[] item_popularity;

  public ItemPopularity(SparseMatrix trainMatrix, ArrayList testRatings,
      int topK, int threadNum) {
    super(trainMatrix, testRatings, topK, threadNum);
    this.item_popularity = new double[this.itemCount];
  }

  /**
   * Sets each item's popularity to the value reported by its training column's itemCount(),
   * i.e. the number of reviews the item received.
   */
  public void buildModel() {
    int item = 0;
    while (item < itemCount) {
      // Measure popularity by number of reviews received.
      item_popularity[item] = trainMatrix.getColRef(item).itemCount();
      item++;
    }
  }

  /** Returns the popularity of item i; the user u is ignored. */
  public double predict(int u, int i) {
    final double score = item_popularity[i];
    return score;
  }

  /**
   * Records a new (u, i) interaction and bumps item i's popularity by one.
   * NOTE(review): the count is incremented even if (u, i) was already present in trainMatrix.
   */
  @Override
  public void updateModel(int u, int i) {
    trainMatrix.setValue(u, i, 1);
    item_popularity[i] = item_popularity[i] + 1;
  }
}
================================================
FILE: src/algorithms/MF_ALS.java
================================================
package algorithms;
import data_structure.Rating;
import data_structure.SparseMatrix;
import data_structure.DenseVector;
import data_structure.DenseMatrix;
import data_structure.Pair;
import data_structure.SparseVector;
import happy.coding.math.Randoms;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import utils.Printer;
/**
 * ALS algorithm of the ICDM'08 paper:
 * Yifan Hu etc. Collaborative Filtering for Implicit Feedback Datasets.
 *
 * <p>Minimizes
 * <pre>
 *   sum_{(u,i) observed} (r_ui - p_ui)^2 + w0 * sum_{(u,i) unobserved} p_ui^2
 *     + reg * (|U|^2 + |V|^2),          with p_ui = U_u . V_i,
 * </pre>
 * where w0 is the constructor argument divided by itemCount. Each user (item) row is solved in
 * closed form from its normal equations. Those equations use the Gram matrix of the opposite
 * side: SV = V^T V when updating a user, SU = U^T U when updating an item.
 *
 * @author xiangnanhe
 */
public class MF_ALS extends TopKRecommender {
  /** Model priors to set. */
  int factors = 10; // number of latent factors.
  int maxIter = 100; // maximum iterations.
  double w0 = 0.01; // weight for 0s; the constructor stores w0 / itemCount
  double reg = 0.01; // regularization parameters
  double init_mean = 0; // Gaussian mean for init V
  double init_stdev = 0.01; // Gaussian std-dev for init V

  /** Model parameters to learn */
  DenseMatrix U; // latent vectors for users
  DenseMatrix V; // latent vectors for items

  /** Caches: SU = U^T U and SV = V^T V (factors x factors), kept in sync by the updates. */
  DenseMatrix SU;
  DenseMatrix SV;

  boolean showProgress;
  boolean showLoss;

  public MF_ALS(SparseMatrix trainMatrix, ArrayList testRatings,
      int topK, int threadNum, int factors, int maxIter, double w0, double reg,
      double init_mean, double init_stdev, boolean showProgress, boolean showLoss) {
    super(trainMatrix, testRatings, topK, threadNum);
    this.factors = factors;
    this.maxIter = maxIter;
    this.w0 = w0 / itemCount;
    this.reg = reg;
    this.init_mean = init_mean;
    this.init_stdev = init_stdev;
    this.showProgress = showProgress;
    this.showLoss = showLoss;
    this.initialize();
  }

  //remove
  public void setUV(DenseMatrix U, DenseMatrix V) {
    this.U = U.clone();
    this.V = V.clone();
    SU = U.transpose().mult(U);
    SV = V.transpose().mult(V);
  }

  /** Randomly initializes U and V and builds the Gram caches from them. */
  private void initialize() {
    U = new DenseMatrix(userCount, factors);
    V = new DenseMatrix(itemCount, factors);
    U.init(init_mean, init_stdev);
    V.init(init_mean, init_stdev);
    SU = U.transpose().mult(U);
    SV = V.transpose().mult(V);
  }

  // Implement the ALS algorithm of the ICDM'08 paper
  public void buildModel() {
    System.out.println("Run for MF_ALS");
    double loss_pre = Double.MAX_VALUE;
    for (int iter = 0; iter < maxIter; iter++) {
      long start = System.currentTimeMillis();
      runOneIteration();
      // Show progress
      if (showProgress)
        showProgress(iter, start, testRatings);
      // Show loss
      if (showLoss)
        loss_pre = showLoss(iter, start, loss_pre);
    }
  }

  // Run model for one iteration: all users first, then all items.
  public void runOneIteration() {
    // Update user latent vectors
    for (int u = 0; u < userCount; u++) {
      update_user(u);
    }
    // Update item latent vectors
    for (int i = 0; i < itemCount; i++) {
      update_item(i);
    }
  }

  /** Solves user u's normal equations exactly and refreshes the SU cache. */
  private void update_user(int u) {
    ArrayList<Integer> itemList = trainMatrix.getRowRef(u).indexList();
    // A_u = w0 * V^T V + (1 - w0) * sum_{i in I_u} v_i v_i^T + reg * I.
    // A user's normal equations involve the ITEM Gram matrix SV, not SU.
    DenseMatrix Au = SV.scale(w0);
    for (int k1 = 0; k1 < factors; k1++) {
      for (int k2 = 0; k2 < factors; k2++) {
        for (int i : itemList)
          Au.add(k1, k2, V.get(i, k1) * V.get(i, k2) * (1 - w0));
      }
    }
    // d_u = sum_{i in I_u} r_ui * v_i
    DenseVector du = new DenseVector(factors);
    for (int k = 0; k < factors; k++) {
      for (int i : itemList)
        du.add(k, V.get(i, k) * trainMatrix.getValue(u, i));
    }
    // Matrix inversion to get the new embedding
    for (int k = 0; k < factors; k++) { // consider the regularizer
      Au.add(k, k, reg);
    }
    DenseVector newVector = Au.inv().mult(du);
    // Update the SU cache: replace the old row's outer product with the new one.
    for (int f = 0; f < factors; f++) {
      for (int k = 0; k <= f; k++) {
        double val = SU.get(f, k) - U.get(u, f) * U.get(u, k)
            + newVector.get(f) * newVector.get(k);
        SU.set(f, k, val);
        SU.set(k, f, val);
      }
    }
    // Update parameters
    for (int k = 0; k < factors; k++) {
      U.set(u, k, newVector.get(k));
    }
  }

  /** Solves item i's normal equations exactly and refreshes the SV cache. */
  private void update_item(int i) {
    ArrayList<Integer> userList = trainMatrix.getColRef(i).indexList();
    // A_i = w0 * U^T U + (1 - w0) * sum_{u in U_i} u_u u_u^T + reg * I.
    // An item's normal equations involve the USER Gram matrix SU, not SV.
    DenseMatrix Ai = SU.scale(w0);
    for (int k1 = 0; k1 < factors; k1++) {
      for (int k2 = 0; k2 < factors; k2++) {
        for (int u : userList)
          Ai.add(k1, k2, U.get(u, k1) * U.get(u, k2) * (1 - w0));
      }
    }
    // d_i = sum_{u in U_i} r_ui * u_u
    DenseVector di = new DenseVector(factors);
    for (int k = 0; k < factors; k++) {
      for (int u : userList)
        di.add(k, U.get(u, k) * trainMatrix.getValue(u, i));
    }
    // Matrix inversion to get the new embedding
    for (int k = 0; k < factors; k++) { // consider the regularizer
      Ai.add(k, k, reg);
    }
    DenseVector newVector = Ai.inv().mult(di);
    // Update the SV cache: replace the old row's outer product with the new one.
    for (int f = 0; f < factors; f++) {
      for (int k = 0; k <= f; k++) {
        double val = SV.get(f, k) - V.get(i, f) * V.get(i, k)
            + newVector.get(f) * newVector.get(k);
        SV.set(f, k, val);
        SV.set(k, f, val);
      }
    }
    // Update parameters
    for (int k = 0; k < factors; k++) {
      V.set(i, k, newVector.get(k));
    }
  }

  /** Prints the current loss (with "-" if it did not increase) and returns it. */
  public double showLoss(int iter, long start, double loss_pre) {
    long start1 = System.currentTimeMillis();
    double loss_cur = loss();
    String symbol = loss_pre >= loss_cur ? "-" : "+";
    System.out.printf("Iter=%d [%s]\t [%s]loss: %.4f [%s]\n", iter,
        Printer.printTime(start1 - start), symbol, loss_cur,
        Printer.printTime(System.currentTimeMillis() - start1));
    return loss_cur;
  }

  /**
   * Fast way to calculate the objective minimized by the ALS updates.
   *
   * <p>sum_u [ sum_{i in I_u} ((r_ui - p_ui)^2 - w0 * p_ui^2) + w0 * u^T (V^T V) u ]
   * + reg * (|U|^2 + |V|^2). The w0 * u^T V^T V u term charges w0 * p^2 for every item;
   * the -w0 * p^2 correction removes that charge for observed items.
   */
  public double loss() {
    // Recompute V^T V from scratch rather than relying on the incrementally updated cache.
    DenseMatrix gramV = new DenseMatrix(factors, factors);
    for (int f = 0; f < factors; f++) {
      for (int k = 0; k <= f; k++) {
        double val = 0;
        for (int i = 0; i < itemCount; i++)
          val += V.get(i, f) * V.get(i, k);
        gramV.set(f, k, val);
        gramV.set(k, f, val);
      }
    }
    double L = reg * (U.squaredSum() + V.squaredSum());
    for (int u = 0; u < userCount; u++) {
      double l = 0;
      for (int i : trainMatrix.getRowRef(u).indexList()) {
        double pred = predict(u, i);
        l += Math.pow(trainMatrix.getValue(u, i) - pred, 2);
        l -= w0 * Math.pow(pred, 2);
      }
      l += w0 * gramV.mult(U.row(u, false)).inner(U.row(u, false));
      L += l;
    }
    return L;
  }

  @Override
  public double predict(int u, int i) {
    return U.row(u, false).inner(V.row(i, false));
  }

  /** Adds the interaction (u, i) and re-solves u and i for maxIterOnline rounds. */
  @Override
  public void updateModel(int u, int i) {
    trainMatrix.setValue(u, i, 1);
    for (int iter = 0; iter < maxIterOnline; iter++) {
      update_user(u);
      update_item(i);
    }
  }
}
================================================
FILE: src/algorithms/MF_CD.java
================================================
package algorithms;
import data_structure.Rating;
import data_structure.SparseMatrix;
import data_structure.DenseVector;
import data_structure.DenseMatrix;
import data_structure.Pair;
import data_structure.SparseVector;
import happy.coding.math.Randoms;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import utils.Printer;
/**
 * Row-wise gradient descent with backtracking line search, following the KDD'15 paper:
 * Robin Devooght etc. Dynamic Matrix Factorization with Priors on Unknown Values.
 *
 * <p>Each user/item row e is updated as e - lr * g. Here g is half the gradient of the per-row
 * objective
 * <pre>
 *   sum_{observed j} W_j * [ (r_j - p_j)^2 - w0 * p_j^2 ] + w0 * e^T S e + reg * |e|^2
 * </pre>
 * S is the Gram matrix of the opposite side (SV for users, SU for items), and lr comes from
 * {@link #linesearch}. With W = 1, observed entries get weight 1 and unobserved entries get
 * weight w0.
 *
 * @author xiangnanhe
 */
public class MF_CD extends TopKRecommender {
  /** Model priors to set. */
  int factors = 10; // number of latent factors.
  int maxIter = 100; // maximum iterations.
  double w0 = 0.01; // weight for 0s; the constructor stores w0 / itemCount
  double reg = 0.01; // regularization parameters
  double init_mean = 0; // Gaussian mean for init V
  double init_stdev = 0.01; // Gaussian std-dev for init V

  /** Priors for line search */
  int LSMaxIter = 10; // max iteration of the line search. Default is 10
  double Alpha = 0.3; // parameter of line search. In the range (0, 0.5).
  double Beta = 0.3; // parameter of line search. In the range (0, 1.0).

  /** Model parameters to learn */
  public DenseMatrix U; // latent vectors for users
  public DenseMatrix V; // latent vectors for items

  /** Caches: SU = U^T U and SV = V^T V, kept in sync by the updates. */
  DenseMatrix SU;
  DenseMatrix SV;

  boolean showProgress;
  boolean showLoss;

  // weight for each positive instance in trainMatrix
  SparseMatrix W;
  // weight of new instance in online learning
  public double w_new = 1;

  public MF_CD(SparseMatrix trainMatrix, ArrayList testRatings,
      int topK, int threadNum, int factors, int maxIter, double w0, double reg,
      double init_mean, double init_stdev, boolean showProgress, boolean showLoss) {
    super(trainMatrix, testRatings, topK, threadNum);
    this.factors = factors;
    this.maxIter = maxIter;
    this.w0 = w0 / itemCount;
    this.reg = reg;
    this.init_mean = init_mean;
    this.init_stdev = init_stdev;
    this.showProgress = showProgress;
    this.showLoss = showLoss;
    this.initialize();
    // By default, the weight for positive instance is uniformly 1.
    W = new SparseMatrix(userCount, itemCount);
    for (int u = 0; u < userCount; u++)
      for (int i : trainMatrix.getRowRef(u).indexList())
        W.setValue(u, i, 1);
  }

  /** Randomly initializes U and V and builds the Gram caches from them. */
  private void initialize() {
    U = new DenseMatrix(userCount, factors);
    V = new DenseMatrix(itemCount, factors);
    U.init(init_mean, init_stdev);
    V.init(init_mean, init_stdev);
    SU = U.transpose().mult(U);
    SV = V.transpose().mult(V);
  }

  /** Replaces the training matrix with a copy and resets all positive weights to 1. */
  public void setTrain(SparseMatrix trainMatrix) {
    this.trainMatrix = new SparseMatrix(trainMatrix);
    W = new SparseMatrix(userCount, itemCount);
    for (int u = 0; u < userCount; u++)
      for (int i : this.trainMatrix.getRowRef(u).indexList())
        W.setValue(u, i, 1);
  }

  public void setLSpriors(int LSMaxIter, double Alpha, double Beta) {
    this.LSMaxIter = LSMaxIter;
    this.Alpha = Alpha;
    this.Beta = Beta;
  }

  // remove
  public void setUV(DenseMatrix U, DenseMatrix V) {
    this.U = U.clone();
    this.V = V.clone();
    SU = U.transpose().mult(U);
    SV = V.transpose().mult(V);
  }

  /**
   * Implement the CD algorithm of the KDD'15 papers
   */
  public void buildModel() {
    //System.out.println("Run for MF_CD.");
    ArrayList<Integer> shuffle_list = new ArrayList<Integer>();
    for (int i = 0; i < itemCount + userCount; i++)
      shuffle_list.add(i);
    double loss_pre = Double.MAX_VALUE;
    for (int iter = 0; iter < maxIter; iter++) {
      long start = System.currentTimeMillis();
      Collections.shuffle(shuffle_list);
      updateInOrder(shuffle_list);
      // Show progress
      if (showProgress)
        showProgress(iter, start, testRatings);
      // Show loss
      if (showLoss)
        loss_pre = showLoss(iter, start, loss_pre);
    } // end for iter
  }

  // Run model for one iteration
  public void runOneIteration() {
    ArrayList<Integer> shuffle_list = new ArrayList<Integer>();
    for (int i = 0; i < itemCount + userCount; i++)
      shuffle_list.add(i);
    Collections.shuffle(shuffle_list);
    updateInOrder(shuffle_list);
  }

  // Updates rows in the given order: index < userCount is a user, otherwise item index - userCount.
  private void updateInOrder(ArrayList<Integer> order) {
    for (int index : order) {
      if (index >= userCount) // for an item
        update_item(index - userCount);
      else // for a user
        update_user(index);
    }
  }

  // Line search (book, Convex Optimization) for the best step size.
  // Returns 0 when no acceptable step is found within LSMaxIter trials.
  private double linesearch(int index, DenseVector embedding,
      DenseVector gradient, int LSMaxIter, double Alpha, double Beta) {
    double step_size = 1.0;
    double init_error = error_row(index, embedding);
    for (int iter = 0; iter < LSMaxIter; iter++) {
      // Build new features (ie embedding) with current step size
      DenseVector newEmbedding = embedding.minus(gradient.scale(step_size));
      // Check if new features are good enough. If not reduce step size
      double new_error = error_row(index, newEmbedding);
      if (new_error > init_error - Alpha * step_size * gradient.squaredSum())
        step_size *= Beta;
      else
        break;
      // Too many iterations, return step_size = 0
      if (iter == LSMaxIter - 1) {
        step_size = 0;
        break;
      }
    }
    return step_size;
  }

  /**
   * Per-row objective whose half-gradient is what update_user/update_item step along.
   * It must stay consistent with those gradients, or the Armijo test in linesearch
   * evaluates a different function than the one being descended.
   *
   * @param index     user id, or userCount + item id
   * @param embedding candidate latent vector for that row
   */
  private double error_row(int index, DenseVector embedding) {
    double err = 0;
    if (index >= userCount) { // for an item
      int i = index - userCount;
      for (int u : trainMatrix.getColRef(i).indexList()) {
        double prediction = U.row(u, false).inner(embedding);
        err += W.getValue(u, i) * (Math.pow(trainMatrix.getValue(u, i) - prediction, 2)
            - w0 * Math.pow(prediction, 2));
      }
      err += w0 * SU.mult(embedding).inner(embedding);
      err += reg * embedding.squaredSum();
    } else { // for a user
      int u = index;
      for (int i : trainMatrix.getRowRef(u).indexList()) {
        double prediction = V.row(i, false).inner(embedding);
        err += W.getValue(u, i) * (Math.pow(trainMatrix.getValue(u, i) - prediction, 2)
            - w0 * Math.pow(prediction, 2));
      }
      err += w0 * SV.mult(embedding).inner(embedding);
      err += reg * embedding.squaredSum();
    }
    return err;
  }

  /** One line-searched gradient step on user u's row; keeps SU in sync. */
  private void update_user(int u) {
    DenseVector embedding = U.row(u, false);
    // Calculate the gradient
    DenseVector gradient = SV.mult(embedding).scale(w0);
    for (int i : trainMatrix.getRowRef(u).indexList()) {
      double mul = W.getValue(u, i) * (predict(u, i) * (1 - w0) - trainMatrix.getValue(u, i));
      gradient.selfAdd(V.row(i, false).scale(mul));
    }
    gradient.selfAdd(embedding.scale(reg)); // with regularizer
    // Line search for learning rate
    double lr = linesearch(u, embedding, gradient, LSMaxIter, Alpha, Beta);
    // Update S cache before updating parameters
    DenseVector new_embedding = embedding.minus(gradient.scale(lr));
    for (int f = 0; f < factors; f++) {
      for (int k = 0; k <= f; k++) {
        double val = SU.get(f, k) - embedding.get(f) * embedding.get(k)
            + new_embedding.get(f) * new_embedding.get(k);
        SU.set(f, k, val);
        SU.set(k, f, val);
      }
    }
    // Parameter update (embedding is the row obtained via U.row(u, false))
    for (int f = 0; f < factors; f++)
      embedding.set(f, new_embedding.get(f));
  }

  /** One line-searched gradient step on item i's row; keeps SV in sync. */
  private void update_item(int i) {
    DenseVector embedding = V.row(i, false);
    // Calculate the gradient
    DenseVector gradient = SU.mult(embedding).scale(w0);
    for (int u : trainMatrix.getColRef(i).indexList()) {
      double mul = W.getValue(u, i) * (predict(u, i) * (1 - w0) - trainMatrix.getValue(u, i));
      gradient.selfAdd(U.row(u, false).scale(mul));
    }
    gradient.selfAdd(embedding.scale(reg)); // with regularizer
    // Line search for learning rate
    double lr = linesearch(userCount + i, embedding, gradient, LSMaxIter, Alpha, Beta);
    // Update SV cache
    DenseVector new_embedding = embedding.minus(gradient.scale(lr));
    for (int f = 0; f < factors; f++) {
      for (int k = 0; k <= f; k++) {
        double val = SV.get(f, k) - embedding.get(f) * embedding.get(k)
            + new_embedding.get(f) * new_embedding.get(k);
        SV.set(f, k, val);
        SV.set(k, f, val);
      }
    }
    // Parameter update (embedding is the row obtained via V.row(i, false))
    for (int f = 0; f < factors; f++)
      embedding.set(f, new_embedding.get(f));
  }

  /** Prints the current loss (with "-" if it did not increase) and returns it. */
  public double showLoss(int iter, long start, double loss_pre) {
    long start1 = System.currentTimeMillis();
    double loss_cur = loss();
    String symbol = loss_pre >= loss_cur ? "-" : "+";
    System.out.printf("Iter=%d [%s]\t [%s]loss: %.4f [%s]\n", iter,
        Printer.printTime(start1 - start), symbol, loss_cur,
        Printer.printTime(System.currentTimeMillis() - start1));
    return loss_cur;
  }

  /**
   * Fast way to calculate the global objective, using the SV cache; consistent with
   * error_row and with the update gradients.
   */
  public double loss() {
    double L = reg * (U.squaredSum() + V.squaredSum());
    for (int u = 0; u < userCount; u++) {
      double l = 0;
      for (int i : trainMatrix.getRowRef(u).indexList()) {
        double pred = predict(u, i);
        l += W.getValue(u, i) * (Math.pow(trainMatrix.getValue(u, i) - pred, 2)
            - w0 * Math.pow(pred, 2));
      }
      l += w0 * SV.mult(U.row(u, false)).inner(U.row(u, false));
      L += l;
    }
    return L;
  }

  @Override
  public double predict(int u, int i) {
    return U.row(u, false).inner(V.row(i, false));
  }

  /** Adds (u, i) with weight w_new, then alternates u/i steps for maxIterOnline rounds. */
  @Override
  public void updateModel(int u, int i) {
    trainMatrix.setValue(u, i, 1);
    W.setValue(u, i, w_new);
    for (int iter = 0; iter < maxIterOnline; iter++) {
      update_user(u);
      update_item(i);
    }
  }
}
================================================
FILE: src/algorithms/MF_fastALS.java
================================================
package algorithms;
import data_structure.Rating;
import data_structure.SparseMatrix;
import data_structure.DenseVector;
import data_structure.DenseMatrix;
import data_structure.Pair;
import data_structure.SparseVector;
import happy.coding.math.Randoms;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import java.util.HashMap;
import utils.Printer;
/**
* Fast ALS for weighted matrix factorization (with imputation)
*
* Each positive (observed) entry (u, i) carries weight W(u, i), and each missing entry of
* item i carries weight Wi[i]. Per loss(), the objective is
*   sum_{observed} W(u,i) * (r_ui - p_ui)^2 + sum_{missing} Wi[i] * p_ui^2
*   + reg * (|U|^2 + |V|^2),   with p_ui = U_u . V_i.
* Parameters are updated one latent dimension at a time in closed form. The missing-entry
* part is handled through the cached K x K matrices SU = U^T U and SV = V^T diag(Wi) V, so
* a row update costs O(K^2 + K * |observed entries of the row|).
* @author xiangnanhe
*/
public class MF_fastALS extends TopKRecommender {
/** Model priors to set. */
int factors = 10; // number of latent factors.
int maxIter = 500; // maximum iterations.
double reg = 0.01; // regularization parameters
// Total weight of missing entries; the constructor spreads it over items as Wi (sum of Wi == w0).
double w0 = 1;
double init_mean = 0; // Gaussian mean for init V
double init_stdev = 0.01; // Gaussian std-dev for init V
/** Model parameters to learn */
public DenseMatrix U; // latent vectors for users
public DenseMatrix V; // latent vectors for items
/** Caches: SU = U^T U and SV = V^T diag(Wi) V (see initS), kept in sync by the updates. */
DenseMatrix SU;
DenseMatrix SV;
// Scratch buffers (predictions, ratings, weights) reused by update_user/update_item.
// They are instance state, so those methods are presumably not safe to call concurrently.
double[] prediction_users, prediction_items;
double[] rating_users, rating_items;
double[] w_users, w_items;
boolean showProgress;
boolean showLoss;
// weight for each positive instance in trainMatrix
SparseMatrix W;
// weight for negative instances on item i.
double[] Wi;
// weight of new instance in online learning
public double w_new = 1;
/**
* @param w0    total weight of missing entries, distributed over items by popularity
* @param alpha exponent of the popularity-based weighting: Wi[i] is proportional to
*              (share of training entries in item i's column) ^ alpha
*/
public MF_fastALS(SparseMatrix trainMatrix, ArrayList testRatings,
int topK, int threadNum, int factors, int maxIter, double w0, double alpha, double reg,
double init_mean, double init_stdev, boolean showProgress, boolean showLoss) {
super(trainMatrix, testRatings, topK, threadNum);
this.factors = factors;
this.maxIter = maxIter;
this.w0 = w0;
this.reg = reg;
this.init_mean = init_mean;
this.init_stdev = init_stdev;
this.showLoss = showLoss;
this.showProgress = showProgress;
// Set the Wi as a decay function w0 * pi ^ alpha
double sum = 0, Z = 0;
double[] p = new double[itemCount];
for (int i = 0; i < itemCount; i ++) {
p[i] = trainMatrix.getColRef(i).itemCount();
sum += p[i];
}
// convert p[i] to probability
// NOTE(review): if trainMatrix has no entries, sum == 0 and every Wi becomes NaN.
for (int i = 0; i < itemCount; i ++) {
p[i] /= sum;
p[i] = Math.pow(p[i], alpha);
Z += p[i];
}
// assign weight (normalized by Z so that the Wi sum to w0)
Wi = new double[itemCount];
for (int i = 0; i < itemCount; i ++)
Wi[i] = w0 * p[i] / Z;
// By default, the weight for positive instance is uniformly 1.
W = new SparseMatrix(userCount, itemCount);
for (int u = 0; u < userCount; u ++)
for (int i : trainMatrix.getRowRef(u).indexList())
W.setValue(u, i, 1);
// Init caches
prediction_users = new double[userCount];
prediction_items = new double[itemCount];
rating_users = new double[userCount];
rating_items = new double[itemCount];
w_users = new double[userCount];
w_items = new double[itemCount];
// Init model parameters
U = new DenseMatrix(userCount, factors);
V = new DenseMatrix(itemCount, factors);
U.init(init_mean, init_stdev);
V.init(init_mean, init_stdev);
initS();
}
// Replaces the training matrix with a copy and resets all positive weights to 1.
// Wi and the SU/SV caches are left unchanged.
public void setTrain(SparseMatrix trainMatrix) {
this.trainMatrix = new SparseMatrix(trainMatrix);
W = new SparseMatrix(userCount, itemCount);
for (int u = 0; u < userCount; u ++)
for (int i : this.trainMatrix.getRowRef(u).indexList())
W.setValue(u, i, 1);
}
// Init SU and SV
private void initS() {
SU = U.transpose().mult(U);
// Init SV as V^T Wi V
SV = new DenseMatrix(factors, factors);
for (int f = 0; f < factors; f ++) {
for (int k = 0; k <= f; k ++) {
double val = 0;
for (int i = 0; i < itemCount; i ++)
val += V.get(i, f) * V.get(i, k) * Wi[i];
SV.set(f, k, val);
SV.set(k, f, val);
}
}
}
//remove
public void setUV(DenseMatrix U, DenseMatrix V) {
this.U = U.clone();
this.V = V.clone();
initS();
}
// Full training: each iteration updates every user, then every item.
public void buildModel() {
//System.out.println("Run for FastALS. ");
double loss_pre = Double.MAX_VALUE;
for (int iter = 0; iter < maxIter; iter ++) {
Long start = System.currentTimeMillis();
// Update user latent vectors
for (int u = 0; u < userCount; u ++) {
update_user(u);
}
// Update item latent vectors
for (int i = 0; i < itemCount; i ++) {
update_item(i);
}
// Show progress
if (showProgress)
showProgress(iter, start, testRatings);
// Show loss
if (showLoss)
loss_pre = showLoss(iter, start, loss_pre);
} // end for iter
}
// Run model for one iteration
public void runOneIteration() {
// Update user latent vectors
for (int u = 0; u < userCount; u ++) {
update_user(u);
}
// Update item latent vectors
for (int i = 0; i < itemCount; i ++) {
update_item(i);
}
}
// Element-wise closed-form update of user u's latent vector, then refresh of the SU cache.
// Users with no training entries are skipped and keep their current vector.
protected void update_user(int u) {
ArrayList itemList = trainMatrix.getRowRef(u).indexList();
if (itemList.size() == 0) return; // user has no ratings
// prediction cache for the user
for (int i : itemList) {
prediction_items[i] = predict(u, i);
rating_items[i] = trainMatrix.getValue(u, i);
w_items[i] = W.getValue(u, i);
}
// Snapshot of the row before the update; needed to patch the SU cache afterwards.
DenseVector oldVector = U.row(u);
for (int f = 0; f < factors; f ++) {
double numer = 0, denom = 0;
// O(K) complexity for the negative part
// (SV already includes the Wi weights, so no extra scaling is applied here.)
for (int k = 0; k < factors; k ++) {
if (k != f)
numer -= U.get(u, k) * SV.get(f, k);
}
//numer *= w0;
// O(Nu) complexity for the positive part
for (int i : itemList) {
// Remove dimension f from the cached prediction so it holds the prediction without f.
prediction_items[i] -= U.get(u, f) * V.get(i, f);
// (w_items[i] - Wi[i]) offsets the Wi weight SV already applied to observed item i.
numer += (w_items[i]*rating_items[i] - (w_items[i]-Wi[i]) * prediction_items[i]) * V.get(i, f);
denom += (w_items[i]-Wi[i]) * V.get(i, f) * V.get(i, f);
}
denom += SV.get(f, f) + reg;
// Parameter Update
U.set(u, f, numer / denom);
// Update the prediction cache
for (int i : itemList)
prediction_items[i] += U.get(u, f) * V.get(i, f);
} // end for f
// Update the SU cache
for (int f = 0; f < factors; f ++) {
for (int k = 0; k <= f; k ++) {
double val = SU.get(f, k) - oldVector.get(f) * oldVector.get(k)
+ U.get(u, f) * U.get(u, k);
SU.set(f, k, val);
SU.set(k, f, val);
}
} // end for f
}
// Element-wise closed-form update of item i's latent vector, then refresh of the SV cache.
// Items with no training entries are skipped and keep their current vector.
protected void update_item(int i) {
ArrayList userList = trainMatrix.getColRef(i).indexList();
if (userList.size() == 0) return; // item has no ratings.
// prediction cache for the item
for (int u : userList) {
prediction_users[u] = predict(u, i);
rating_users[u] = trainMatrix.getValue(u, i);
w_users[u] = W.getValue(u, i);
}
// Snapshot of the row before the update; needed to patch the SV cache afterwards.
DenseVector oldVector = V.row(i);
for (int f = 0; f < factors; f++) {
// O(K) complexity for the w0 part
double numer = 0, denom = 0;
for (int k = 0; k < factors; k ++) {
if (k != f)
numer -= V.get(i, k) * SU.get(f, k);
}
// SU is unweighted, so the item's missing-entry weight is applied here.
numer *= Wi[i];
// O(Ni) complexity for the positive ratings part
for (int u : userList) {
prediction_users[u] -= U.get(u, f) * V.get(i, f);
numer += (w_users[u]*rating_users[u] - (w_users[u]-Wi[i]) * prediction_users[u]) * U.get(u, f);
denom += (w_users[u]-Wi[i]) * U.get(u, f) * U.get(u, f);
}
denom += Wi[i] * SU.get(f, f) + reg;
// Parameter update
V.set(i, f, numer / denom);
// Update the prediction cache for the item
for (int u : userList)
prediction_users[u] += U.get(u, f) * V.get(i, f);
} // end for f
// Update the SV cache (each term scaled by Wi[i] to match V^T diag(Wi) V)
for (int f = 0; f < factors; f ++) {
for (int k = 0; k <= f; k ++) {
double val = SV.get(f, k) - oldVector.get(f) * oldVector.get(k) * Wi[i]
+ V.get(i, f) * V.get(i, k) * Wi[i];
SV.set(f, k, val);
SV.set(k, f, val);
}
}
}
// Prints the current loss (with "-" if it did not increase) and returns it.
public double showLoss(int iter, long start, double loss_pre) {
long start1 = System.currentTimeMillis();
double loss_cur = loss();
String symbol = loss_pre >= loss_cur ? "-" : "+";
System.out.printf("Iter=%d [%s]\t [%s]loss: %.4f [%s]\n", iter,
Printer.printTime(start1 - start), symbol, loss_cur,
Printer.printTime(System.currentTimeMillis() - start1));
return loss_cur;
}
// Fast way to calculate the loss function.
// u^T SV u charges Wi[i] * p_ui^2 for every item; the -Wi[i] * pred^2 term removes that
// charge for observed items, leaving the missing-entry penalty only on unobserved entries.
public double loss() {
double L = reg * (U.squaredSum() + V.squaredSum());
for (int u = 0; u < userCount; u ++) {
double l = 0;
for (int i : trainMatrix.getRowRef(u).indexList()) {
double pred = predict(u, i);
l += W.getValue(u, i) * Math.pow(trainMatrix.getValue(u, i) - pred, 2);
l -= Wi[i] * Math.pow(pred, 2);
}
l += SV.mult(U.row(u, false)).inner(U.row(u, false));
L += l;
}
return L;
}
@Override
public double predict(int u, int i) {
return U.row(u, false).inner(V.row(i, false));
}
// Online update: records (u, i) with weight w_new, then alternates u/i updates maxIterOnline times.
// An item whose Wi is 0 (e.g. one unseen in training) is first given weight w0 / itemCount, and
// its contribution is added to the SV cache.
@Override
public void updateModel(int u, int i) {
trainMatrix.setValue(u, i, 1);
W.setValue(u, i, w_new);
if (Wi[i] == 0) { // an new item
Wi[i] = w0 / itemCount;
// Update the SV cache
for (int f = 0; f < factors; f ++) {
for (int k = 0; k <= f; k ++) {
double val = SV.get(f, k) + V.get(i, f) * V.get(i, k) * Wi[i];
SV.set(f, k, val);
SV.set(k, f, val);
}
}
}
for (int iter = 0; iter < maxIterOnline; iter ++) {
update_user(u);
update_item(i);
}
}
// NOTE(review): the commented-out loss below uses a uniform w0 and (1 - w0) scaling. It does not
// reflect the per-item Wi / per-entry W weighting used above; treat it as stale.
/* // Raw way to calculate the loss function
public double loss() {
double L = reg * (U.squaredSum() + V.squaredSum());
for (int u = 0; u < userCount; u ++) {
double l = 0;
for (int i : trainMatrix.getRowRef(u).indexList()) {
l += Math.pow(trainMatrix.getValue(u, i) - predict(u, i), 2);
}
l *= (1 - w0);
for (int i = 0; i < itemCount; i ++) {
l += w0 * Math.pow(predict(u, i), 2);
}
L += l;
}
return L;
} */
}
================================================
FILE: src/algorithms/MFbpr.java
================================================
package algorithms;
import data_structure.Rating;
import data_structure.SparseMatrix;
import data_structure.DenseVector;
import data_structure.DenseMatrix;
import data_structure.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import utils.Printer;
/**
* Implement the standard matrix factorization model, optimized by BPR loss.
* Rendle, Steffen, et al. "BPR: Bayesian personalized ranking from implicit feedback."
* Proc. of UAI 2009.
*
* Adaptive learning rate see the KDD'11 paper
* Large-Scale Matrix Factorization with Distributed Stochastic Gradient Descent
* @author xiangnanhe
*
*/
public class MFbpr extends TopKRecommender {
    /** Model priors to set. */
    int factors = 10;          // number of latent factors.
    int maxIter = 100;         // maximum iterations.
    double lr = 0.01;          // Learning rate
    boolean adaptive = false;  // Whether to use adaptive learning rate
    double reg = 0.01;         // regularization parameters
    double init_mean = 0;      // Gaussian mean for init V
    double init_stdev = 0.1;   // Gaussian std-dev for init V
    // Dynamic Negative Sampling [Zhang et al. SIGIR 2013]: sample X negatives and use the one
    // with maximum predicted value as the true negative.
    double num_dns = 1;        // number of dynamic negative samples.

    /** Model parameters to learn */
    public DenseMatrix U;      // latent vectors for users
    public DenseMatrix V;      // latent vectors for items

    /** Whether buildModel() prints an evaluation every 10 epochs. */
    boolean showProgress;
    /**
     * Online retraining mode used by updateModel(): "u" (case-insensitive) replays all of the
     * user's training items; any other value replays only the newly observed item.
     */
    public String onlineMode = "u";
    /**
     * Source of randomness for all sampling and shuffling. Assign a seeded instance to make
     * training reproducible.
     */
    Random rand = new Random();

    /**
     * Creates the model and initializes U and V with Gaussian noise.
     *
     * @param trainMatrix  user-item training matrix (handed to TopKRecommender)
     * @param testRatings  test ratings used for progress reports and the adaptive learning rate
     * @param topK         ranking cutoff used in evaluation
     * @param threadNum    number of evaluation threads
     * @param factors      number of latent factors
     * @param maxIter      number of training epochs run by buildModel()
     * @param lr           initial SGD learning rate
     * @param adaptive     whether to adapt lr after every epoch (see buildModel())
     * @param reg          L2 regularization strength
     * @param init_mean    mean of the Gaussian used to initialize U and V
     * @param init_stdev   std-dev of the Gaussian used to initialize U and V
     * @param num_dns      negatives sampled per update; the highest-scoring one is used
     * @param showProgress whether to print an evaluation every 10 epochs
     */
    public MFbpr(SparseMatrix trainMatrix, ArrayList testRatings,
            int topK, int threadNum, int factors, int maxIter, double lr, boolean adaptive, double reg,
            double init_mean, double init_stdev, int num_dns, boolean showProgress) {
        super(trainMatrix, testRatings, topK, threadNum);
        this.factors = factors;
        this.maxIter = maxIter;
        this.lr = lr;
        this.adaptive = adaptive;
        this.reg = reg;
        this.init_mean = init_mean;
        this.init_stdev = init_stdev;
        this.num_dns = num_dns;
        this.showProgress = showProgress;

        // Init model parameters
        U = new DenseMatrix(userCount, factors);
        V = new DenseMatrix(itemCount, factors);
        U.init(init_mean, init_stdev);
        V.init(init_mean, init_stdev);
    }

    /** Replaces the latent factors with deep copies of the given matrices. (Marked "remove" upstream.) */
    @Override
    public void setUV(DenseMatrix U, DenseMatrix V) {
        this.U = U.clone();
        this.V = V.clone();
    }

    /**
     * Trains for maxIter epochs of SGD. When adaptive is set, after each epoch lr is multiplied
     * by 1.05 if the mean NDCG beat the previous epoch's, and halved otherwise.
     */
    @Override
    public void buildModel() {
        double prevMetric = 0;
        for (int iter = 0; iter < maxIter; iter++) {
            long start = System.currentTimeMillis();
            runEpoch();

            // Show progress per 10 epochs
            boolean evaluatedThisEpoch = showProgress && iter % 10 == 0;
            if (evaluatedThisEpoch)
                showProgress(iter, start, testRatings);

            // Adjust the learning rate
            if (adaptive) {
                // ndcgs must reflect *this* epoch; showProgress only refreshes it every 10 epochs,
                // and comparing against a stale value would halve lr on the other epochs.
                if (!evaluatedThisEpoch)
                    evaluate(testRatings);
                double metric = ndcgs.mean();
                lr = metric > prevMetric ? lr * 1.05 : lr * 0.5;
                prevMetric = metric;
            }
        }
    }

    /** Runs a single training epoch (see runEpoch()). */
    @Override
    public void runOneIteration() {
        runEpoch();
    }

    // One epoch: trainMatrix.itemCount() SGD steps, each on a uniformly sampled user and one of
    // that user's training items.
    private void runEpoch() {
        int nonzeros = trainMatrix.itemCount();
        for (int s = 0; s < nonzeros; s++) {
            // sample a user
            int u = rand.nextInt(userCount);
            ArrayList<Integer> itemList = trainMatrix.getRowRef(u).indexList();
            if (itemList.isEmpty())
                continue;
            // sample a positive item
            int i = itemList.get(rand.nextInt(itemList.size()));
            // One SGD step update
            update_ui(u, i);
        }
    }

    // One SGD step of BPR for the positive pair (u, i) with a dynamically sampled negative item.
    private void update_ui(int u, int i) {
        // A user who has interacted with every item has no negative; rejection sampling below
        // would loop forever.
        if (trainMatrix.getRowRef(u).itemCount() >= itemCount)
            return;

        // Dynamic negative sampling: keep the sampled negative with the largest prediction.
        int j = sampleNegative(u);
        double y_neg = predict(u, j);
        for (int k = 1; k < this.num_dns; k++) {
            int s = sampleNegative(u);
            double score = predict(u, s);
            if (score > y_neg) {
                j = s;
                y_neg = score;
            }
        }

        // BPR update rules
        double y_pos = predict(u, i);  // target value of positive instance
        double mult = -partial_loss(y_pos - y_neg);
        for (int f = 0; f < factors; f++) {
            // Read all three factors before writing so every gradient uses the pre-step values.
            double uf = U.get(u, f);
            double vif = V.get(i, f);
            double vjf = V.get(j, f);
            U.add(u, f, -lr * (mult * (vif - vjf) + reg * uf));
            V.add(i, f, -lr * (mult * uf + reg * vif));
            V.add(j, f, -lr * (-mult * uf + reg * vjf));
        }
    }

    // Uniformly samples an item with a zero entry in user u's training row.
    private int sampleNegative(int u) {
        int s;
        do {
            s = rand.nextInt(itemCount);
        } while (trainMatrix.getValue(u, s) != 0);
        return s;
    }

    @Override
    public double predict(int u, int i) {
        return U.row(u, false).inner(V.row(i, false));
    }

    // Magnitude of the derivative of -ln sigmoid(x), i.e. sigmoid(-x).
    // Algebraically equal to exp(-x) / (1 + exp(-x)), but this form never produces Inf/Inf = NaN
    // when x is very negative.
    private double partial_loss(double x) {
        return 1.0 / (1.0 + Math.exp(x));
    }

    /**
     * Implement the Recsys08 method: Steffen Rendle, Lars Schmidt-Thieme,
     * "Online-Updating Regularized Kernel Matrix Factorization Models".
     * Marks (u, item) as observed, then runs maxIterOnline passes of SGD steps for user u.
     */
    @Override
    public void updateModel(int u, int item) {
        trainMatrix.setValue(u, item, 1);

        // user retrain
        ArrayList<Integer> itemList = trainMatrix.getRowRef(u).indexList();
        boolean userMode = onlineMode.equalsIgnoreCase("u");
        for (int iter = 0; iter < maxIterOnline; iter++) {
            if (userMode)
                Collections.shuffle(itemList, rand);
            for (int s = 0; s < itemList.size(); s++) {
                // retrain for the user or for the (user, item) pair
                int i = userMode ? itemList.get(s) : item;
                update_ui(u, i);
            }
        }
    }
}
================================================
FILE: src/algorithms/TopKRecommender.java
================================================
package algorithms;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import utils.CommonUtils;
import utils.Printer;
import data_structure.DenseVector;
import data_structure.Rating;
import data_structure.SparseMatrix;
import data_structure.DenseMatrix;
import utils.TopKPriorityQueue;
import java.util.Map;
/**
* This is an abstract class for topK recommender systems.
* Define some variables to use, and member functions to implement by a topK recommender.
*
* @author HeXiangnan
* @since 2014.12.03
*/
public abstract class TopKRecommender {
/** The number of users. */
public int userCount;
/** The number of items. */
public int itemCount;
/** Rating matrix of training set. Users by Items.*/
public SparseMatrix trainMatrix;
/** Test ratings. For showing progress only. */
public ArrayList testRatings;
/** Position to cutoff. */
public int topK = 100;
/** Number of threads to run the model (if multi-thread implementation).*/
public int threadNum = 1;
/** Evaluation for each user (offline eval) or test instance (online eval).*/
public DenseVector hits;
public DenseVector ndcgs;
public DenseVector precs;
public int maxIterOnline = 1;
public boolean ignoreTrain = false; // ignore train items when generating topK list
public TopKRecommender() {};
public TopKRecommender(SparseMatrix trainMatrix,
ArrayList testRatings, int topK, int threadNum) {
this.trainMatrix = new SparseMatrix(trainMatrix);
this.testRatings = new ArrayList(testRatings);
this.topK = topK;
this.threadNum = threadNum;
this.userCount = trainMatrix.length()[0];
this.itemCount = trainMatrix.length()[1];
}
/**
* Get the prediction score of user u on item i. To be overridden.
*/
public abstract double predict(int u, int i);
/**
* Build the model.
*/
public abstract void buildModel();
/**
* Update the model with a new observation.
*/
public abstract void updateModel(int u, int i);
/**
* Show progress (evaluation) with current model parameters.
* @iter Current iteration
* @start Starting time of the iteration
* @testMatrix For evaluation purpose
*/
public void showProgress(int iter, long start, ArrayList testRatings) {
long end_iter = System.currentTimeMillis();
if (userCount == testRatings.size()) // leave-1-out eval
evaluate(testRatings);
else // global split
evaluateOnline(testRatings, 100);
long end_eval = System.currentTimeMillis();
System.out.printf("Iter=%d[%s] :\t %.4f\t %.4f\t %.4f\t %.4f\t [%s]\n",
iter, Printer.printTime(end_iter - start), loss(),
hits.mean(), ndcgs.mean(), precs.mean(), Printer.printTime(end_eval - end_iter));
}
/**
* Online evaluation (global split) by simulating the testing stream.
* @param ratings Test ratings that are sorted by time (old -> recent).
* @param interval Print evaluation result per X iteration.
*/
public void evaluateOnline(ArrayList testRatings, int interval) {
int testCount = testRatings.size();
hits = new DenseVector(testCount);
ndcgs = new DenseVector(testCount);
precs = new DenseVector(testCount);
// break down the results by number of user ratings of the test pair
int intervals = 10;
int[] counts = new int[intervals + 1];
double[] hits_r = new double[intervals + 1];
double[] ndcgs_r = new double[intervals + 1];
double[] precs_r = new double[intervals + 1];
Long updateTime = (long) 0;
for (int i = 0; i < testCount; i ++) {
// Check performance per interval:
if (i > 0 && interval > 0 && i % interval == 0) {
System.out.printf("%d:
=\t %.4f\t %.4f\t %.4f\n",
i, hits.sum() / i, ndcgs.sum() / i, precs.sum() / i);
}
// Evaluate model of the current test rating:
Rating rating = testRatings.get(i);
double[] res = this.evaluate_for_user(rating.userId, rating.itemId);
hits.set(i, res[0]);
ndcgs.set(i, res[1]);
precs.set(i, res[2]);
// statisitcs for break down
int r = trainMatrix.getRowRef(rating.userId).itemCount();
r = r> intervals ? intervals : r;
counts[r] += 1;
hits_r[r] += res[0];
ndcgs_r[r] += res[1];
precs_r[r] += res[2];
// Update the model
Long start = System.currentTimeMillis();
updateModel(rating.userId, rating.itemId);
updateTime += (System.currentTimeMillis() - start);
}
System.out.println("Break down the results by number of user ratings for the test pair.");
System.out.printf("#Rating\t Percentage\t HR\t NDCG\t MAP\n");
for (int i = 0; i <= intervals; i ++) {
System.out.printf("%d\t %.2f%%\t %.4f\t %.4f\t %.4f \n",
i, (double)counts[i] / testCount * 100,
hits_r[i] / counts[i], ndcgs_r[i] / counts[i], precs_r[i] / counts[i]);
}
System.out.printf("Avg model update time per instance: %.2f ms\n", (float)updateTime/testCount);
}
protected ArrayList threadSplit(int total, int threadNum, int t) {
ArrayList res = new ArrayList();
int start = (total / threadNum) * t;
int end = (t == threadNum-1) ? total :
(total / threadNum) * (t + 1);
for (int i = start; i < end; i ++)
res.add(i);
return res;
}
/**
* Offline evaluation (leave-1-out) for each user.
* @param topK position to cutoff
* @param testMatrix
* @throws InterruptedException
*/
public void evaluate(ArrayList testRatings) {
assert userCount == testRatings.size();
for (int u = 0; u < userCount; u ++)
assert u == testRatings.get(u).userId;
hits = new DenseVector(userCount);
ndcgs = new DenseVector(userCount);
precs = new DenseVector(userCount);
// Run the evaluation multi-threads splitted by users
EvaluationThread[] threads = new EvaluationThread[threadNum];
for (int t = 0; t < threadNum; t ++) {
ArrayList users = threadSplit(userCount, threadNum, t);
threads[t] = new EvaluationThread(this, testRatings, users);
threads[t].start();
}
// Wait until all threads are finished.
for (int t = 0; t < threads.length; t++) {
try {
threads[t].join();
} catch (InterruptedException e) {
System.err.println("InterruptException was caught: " + e.getMessage());
}
}
}
/**
* Evaluation for a specific user with given GT item.
* @return:
* result[0]: hit ratio
* result[1]: ndcg
* result[2]: precision
*/
protected double[] evaluate_for_user(int u, int gtItem) {
double[] result = new double[3];
HashMap map_item_score = new HashMap();
// Get the score of the test item first.
double maxScore = predict(u, gtItem);
// Early stopping if there are topK items larger than maxScore.
int countLarger = 0;
for (int i = 0; i < itemCount; i++) {
double score = predict(u, i);
map_item_score.put(i, score);
if (score > maxScore) countLarger ++;
if (countLarger > topK) return result; // early stopping
}
// Selecting topK items (does not exclude train items).
ArrayList rankList = ignoreTrain ?
CommonUtils.TopKeysByValue(map_item_score, topK, trainMatrix.getRowRef(u).indexList()) :
CommonUtils.TopKeysByValue(map_item_score, topK, null);
result[0] = getHitRatio(rankList, gtItem);
result[1] = getNDCG(rankList, gtItem);
result[2] = getPrecision(rankList, gtItem);
return result;
}
/**
* Compute Hit Ratio.
* @param rankList A list of ranked item IDs
* @param gtItem The ground truth item.
* @return Hit ratio.
*/
public double getHitRatio(List rankList, int gtItem) {
for (int item : rankList) {
if (item == gtItem) return 1;
}
return 0;
}
/**
* Compute NDCG of a list of ranked items.
* See http://recsyswiki.com/wiki/Discounted_Cumulative_Gain
* @param rankList a list of ranked item IDs
* @param gtItem The ground truth item.
* @return NDCG.
*/
public double getNDCG(List rankList, int gtItem) {
for (int i = 0; i < rankList.size(); i++) {
int item = rankList.get(i);
if (item == gtItem)
return Math.log(2) / Math.log(i+2);
}
return 0;
}
public double getPrecision(List rankList, int gtItem) {
for (int i = 0; i < rankList.size(); i++) {
int item = rankList.get(i);
if (item == gtItem)
return 1.0 / (i + 1);
}
return 0;
}
// remove
public void runOneIteration() {}
// remove
public double loss() {return 0;}
// remove
public void setUV(DenseMatrix U, DenseMatrix V) {};
}
/**
 * Worker thread for the offline (leave-1-out) evaluation. For each user in its slice it scores
 * the user's test item and stores the three metrics at index u of the model's result vectors.
 * NOTE(review): it relies on each thread receiving a disjoint user slice (see threadSplit).
 */
class EvaluationThread extends Thread {
    TopKRecommender model;
    ArrayList testRatings;
    ArrayList users;

    public EvaluationThread(TopKRecommender model, ArrayList testRatings,
            ArrayList users) {
        this.users = users;
        this.testRatings = testRatings;
        this.model = model;
    }

    @Override
    public void run() {
        for (int k = 0, n = users.size(); k < n; k++) {
            int user = users.get(k);
            int gtItem = testRatings.get(user).itemId;
            double[] metrics = model.evaluate_for_user(user, gtItem);
            model.hits.set(user, metrics[0]);
            model.ndcgs.set(user, metrics[1]);
            model.precs.set(user, metrics[2]);
        }
    }
}
================================================
FILE: src/data_structure/DataMap.java
================================================
package data_structure;
import java.util.HashMap;
import java.util.Iterator;
import java.io.Serializable;
/**
 * This is a class implementing HashMap-based data map.
 * This data structure is used for implementing sparse vector and matrix.
 *
 * Storing a {@code null} value removes the key, so the map never holds null values.
 * NOTE(review): serialization only succeeds if the Key and Val instances are themselves
 * serializable.
 *
 * @author Joonseok Lee
 * @since 2012. 4. 20
 * @version 1.1
 */
public class DataMap<Key extends Comparable<Key>, Val> implements Iterable<Key>, Serializable {
    private static final long serialVersionUID = 8001;

    /** Key-value mapping structure */
    private final HashMap<Key, Val> map;

    /*========================================
     * Constructors
     *========================================*/
    /** Basic constructor without specifying the capacity. */
    public DataMap() {
        map = new HashMap<Key, Val>();
    }

    /**
     * A constructor specifying the capacity.
     * BE CAREFUL TO USE THIS! Never set the capacity much larger than actually needed:
     * it wastes memory and reduces the performance of your program.
     */
    public DataMap(int capacity) {
        map = new HashMap<Key, Val>(capacity);
    }

    /*========================================
     * Getter/Setter
     *========================================*/
    /**
     * Get a data value by the given key.
     *
     * @param key The key to search.
     * @return The data value associated with the given key, or null if absent.
     */
    public Val get(Key key) {
        return map.get(key);
    }

    /**
     * Set a data value with the given key. A {@code null} value removes the key instead.
     *
     * @param key The key to set.
     * @param value The data value associated with the given key.
     */
    public void put(Key key, Val value) {
        if (value == null) {
            map.remove(key);
        } else {
            map.put(key, value);
        }
    }

    /**
     * Remove a data element with the given key.
     *
     * @param key The key to remove.
     * @return The data value deleted with the given key, or null if it was absent.
     */
    public Val remove(Key key) {
        return map.remove(key);
    }

    /**
     * Check whether the map has a specific key inside it.
     *
     * @param key The key to search.
     * @return true if the map has the given key, false otherwise.
     */
    public boolean contains(Key key) {
        return map.containsKey(key);
    }

    /**
     * Get an iterator over the keys of the map (in HashMap order).
     *
     * @return The Iterator instance for the map's keys.
     */
    @Override
    public Iterator<Key> iterator() {
        return map.keySet().iterator();
    }

    /*========================================
     * Properties
     *========================================*/
    /**
     * Count the number of elements in the map.
     *
     * @return The number of items in the map.
     */
    public int itemCount() {
        return map.size();
    }
}
================================================
FILE: src/data_structure/DenseMatrix.java
================================================
// Copyright (C) 2014 Guibing Guo
//
// This file is part of LibRec.
//
// LibRec is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// LibRec is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with LibRec. If not, see .
//
package data_structure;
import happy.coding.io.Strings;
import happy.coding.math.Randoms;
import java.io.Serializable;
import java.util.Arrays;
/**
 * Data Structure: dense matrix
 *
 * A big reason that we do not adopt the original DenseMatrix from the M4J library is
 * that the latter uses a one-dimensional array to store data, which will often cause an
 * OutOfMemory exception due to the limit on the maximum length of a one-dimensional
 * Java array. Here data is stored row by row in a {@code double[][]}.
 *
 * @author guoguibing
 *
 */
public class DenseMatrix implements Serializable {
    private static final long serialVersionUID = -2069621030647530185L;

    // dimension
    protected int numRows, numColumns;
    // read data
    protected double[][] data;

    /**
     * Construct a zero-filled dense matrix with specified dimensions
     *
     * @param numRows
     *            number of rows
     * @param numColumns
     *            number of columns
     */
    public DenseMatrix(int numRows, int numColumns) {
        this.numRows = numRows;
        this.numColumns = numColumns;
        data = new double[numRows][numColumns];
    }

    /**
     * Construct a dense matrix by copying data from a given 2D array. The column count is
     * taken from {@code array[0]}, and that many entries are copied from every row.
     *
     * @param array
     *            data array
     */
    public DenseMatrix(double[][] array) {
        this(array.length, array[0].length);
        for (int i = 0; i < numRows; i++)
            System.arraycopy(array[i], 0, data[i], 0, numColumns);
    }

    /**
     * Construct a dense matrix by copying data from a given matrix
     *
     * @param mat
     *            input matrix
     */
    public DenseMatrix(DenseMatrix mat) {
        this(mat.data);
    }

    /**
     * Make a deep copy of current matrix
     */
    @Override
    public DenseMatrix clone() {
        return new DenseMatrix(this);
    }

    /**
     * Construct an identity matrix
     *
     * @param dim
     *            dimension
     * @return an identity matrix
     */
    public static DenseMatrix eye(int dim) {
        DenseMatrix mat = new DenseMatrix(dim, dim);
        for (int i = 0; i < mat.numRows; i++)
            mat.set(i, i, 1.0);
        return mat;
    }

    /**
     * Initialize a dense matrix with small Gaussian values
     *
     * NOTE: small initial values make it easier to train a
     * model; otherwise a very small learning rate may be needed (especially
     * when the number of factors is large) which can cause bad performance.
     */
    public void init(double mean, double sigma) {
        for (int i = 0; i < numRows; i++)
            for (int j = 0; j < numColumns; j++)
                data[i][j] = Randoms.gaussian(mean, sigma);
    }

    /**
     * initialize a dense matrix with small random values in (0, range)
     */
    public void init(double range) {
        for (int i = 0; i < numRows; i++)
            for (int j = 0; j < numColumns; j++)
                data[i][j] = Randoms.uniform(0, range);
    }

    /**
     * initialize a dense matrix with small random values in (0, 1)
     */
    public void init() {
        init(1.0);
    }

    /**
     * @return number of rows
     */
    public int numRows() {
        return numRows;
    }

    /**
     * @return number of columns
     */
    public int numColumns() {
        return numColumns;
    }

    /**
     * @param rowId
     *            row id
     * @return a copy of row data as a dense vector
     */
    public DenseVector row(int rowId) {
        return row(rowId, true);
    }

    /**
     *
     * @param rowId
     *            row id
     * @param deep
     *            whether to copy the data; {@code false} returns a vector that shares
     *            this matrix's row storage (faster, but writes go through)
     * @return a vector of a specific row
     */
    public DenseVector row(int rowId, boolean deep) {
        return new DenseVector(data[rowId], deep);
    }

    /**
     * @param column
     *            column id
     * @return a copy of column data as a dense vector
     */
    public DenseVector column(int column) {
        DenseVector vec = new DenseVector(numRows);
        for (int i = 0; i < numRows; i++)
            vec.set(i, data[i][column]);
        return vec;
    }

    /**
     * Compute mean of a column of the current matrix
     *
     * @param column
     *            column id
     * @return mean of a column of the current matrix
     */
    public double columnMean(int column) {
        double sum = 0.0;
        for (int i = 0; i < numRows; i++)
            sum += data[i][column];
        return sum / numRows;
    }

    /**
     * @return squared sum of all elements of the matrix.
     */
    public double squaredSum() {
        double res = 0;
        for (int i = 0; i < numRows; i++)
            for (int j = 0; j < numColumns; j++)
                res += data[i][j] * data[i][j];
        return res;
    }

    /**
     * @return the Frobenius norm, i.e. sqrt of the squared sum of all elements
     */
    public double norm() {
        return Math.sqrt(squaredSum());
    }

    /**
     * row x row of two matrix
     *
     * @param m
     *            the first matrix
     * @param mrow
     *            row of the first matrix
     * @param n
     *            the second matrix
     * @param nrow
     *            row of the second matrix
     * @return inner product of two row vectors
     */
    public static double rowMult(DenseMatrix m, int mrow, DenseMatrix n, int nrow) {
        assert m.numColumns == n.numColumns;
        double res = 0;
        for (int j = 0, k = m.numColumns; j < k; j++)
            res += m.get(mrow, j) * n.get(nrow, j);
        return res;
    }

    /**
     * column x column of two matrix
     *
     * @param m
     *            the first matrix
     * @param mcol
     *            column of the first matrix
     * @param n
     *            the second matrix
     * @param ncol
     *            column of the second matrix
     * @return inner product of two column vectors
     */
    public static double colMult(DenseMatrix m, int mcol, DenseMatrix n, int ncol) {
        assert m.numRows == n.numRows;
        double res = 0;
        for (int j = 0, k = m.numRows; j < k; j++)
            res += m.get(j, mcol) * n.get(j, ncol);
        return res;
    }

    /**
     * dot product of row x col between two matrices
     *
     * @param m
     *            the first matrix
     * @param mrow
     *            row id of the first matrix
     * @param n
     *            the second matrix
     * @param ncol
     *            column id of the second matrix
     * @return dot product of row of the first matrix and column of the second
     *         matrix
     */
    public static double product(DenseMatrix m, int mrow, DenseMatrix n, int ncol) {
        assert m.numColumns == n.numRows;
        double res = 0;
        for (int j = 0; j < m.numColumns; j++)
            res += m.get(mrow, j) * n.get(j, ncol);
        return res;
    }

    /**
     * Matrix multiplication with a dense matrix
     *
     * @param mat
     *            a dense matrix
     * @return a dense matrix with results of matrix multiplication
     */
    public DenseMatrix mult(DenseMatrix mat) {
        assert this.numColumns == mat.numRows;
        DenseMatrix res = new DenseMatrix(this.numRows, mat.numColumns);
        for (int i = 0; i < res.numRows; i++) {
            for (int j = 0; j < res.numColumns; j++) {
                double product = 0;
                for (int k = 0; k < this.numColumns; k++)
                    product += data[i][k] * mat.data[k][j];
                res.set(i, j, product);
            }
        }
        return res;
    }

    /**
     * Do {@code matrix x vector} between current matrix and a given vector
     *
     * @return a dense vector with the results of {@code matrix x vector}
     */
    public DenseVector mult(DenseVector vec) {
        assert this.numColumns == vec.size;
        DenseVector res = new DenseVector(this.numRows);
        for (int i = 0; i < this.numRows; i++)
            res.set(i, row(i, false).inner(vec));
        return res;
    }

    /**
     * Get the value at entry [row, column]
     */
    public double get(int row, int column) {
        return data[row][column];
    }

    /**
     * Set a value to entry [row, column]
     */
    public void set(int row, int column, double val) {
        data[row][column] = val;
    }

    /**
     * Add a value to entry [row, column]
     */
    public void add(int row, int column, double val) {
        data[row][column] += val;
    }

    /**
     * @return a new matrix by scaling the current matrix
     */
    public DenseMatrix scale(double val) {
        DenseMatrix mat = new DenseMatrix(numRows, numColumns);
        for (int i = 0; i < numRows; i++)
            for (int j = 0; j < numColumns; j++)
                mat.data[i][j] = this.data[i][j] * val;
        return mat;
    }

    /**
     * Scaling on the current matrix
     */
    public void selfScale(double val) {
        for (int i = 0; i < numRows; i++)
            for (int j = 0; j < numColumns; j++)
                this.data[i][j] = this.data[i][j] * val;
    }

    /**
     * Do {@code A + B} matrix operation
     *
     * @return a matrix with results of {@code C = A + B}
     */
    public DenseMatrix add(DenseMatrix mat) {
        assert numRows == mat.numRows;
        assert numColumns == mat.numColumns;
        DenseMatrix res = new DenseMatrix(numRows, numColumns);
        for (int i = 0; i < numRows; i++)
            for (int j = 0; j < numColumns; j++)
                res.data[i][j] = data[i][j] + mat.data[i][j];
        return res;
    }

    /**
     * In-place {@code A += B}.
     */
    public void selfAdd(DenseMatrix mat) {
        assert numRows == mat.numRows;
        assert numColumns == mat.numColumns;
        for (int i = 0; i < numRows; i++)
            for (int j = 0; j < numColumns; j++)
                this.data[i][j] += mat.data[i][j];
    }

    /**
     * Do {@code A + c} matrix operation, where {@code c} is a constant. Each
     * entries will be added by {@code c}
     *
     * @return a new matrix with results of {@code C = A + c}
     */
    public DenseMatrix add(double val) {
        DenseMatrix res = new DenseMatrix(numRows, numColumns);
        for (int i = 0; i < numRows; i++)
            for (int j = 0; j < numColumns; j++)
                res.data[i][j] = data[i][j] + val;
        return res;
    }

    /**
     * Do {@code A - B} matrix operation
     *
     * @return a matrix with results of {@code C = A - B}
     */
    public DenseMatrix minus(DenseMatrix mat) {
        assert numRows == mat.numRows;
        assert numColumns == mat.numColumns;
        DenseMatrix res = new DenseMatrix(numRows, numColumns);
        for (int i = 0; i < numRows; i++)
            for (int j = 0; j < numColumns; j++)
                res.data[i][j] = data[i][j] - mat.data[i][j];
        return res;
    }

    /**
     * Do {@code A - c} matrix operation, where {@code c} is a constant. The
     * constant {@code c} is subtracted from every entry.
     *
     * @return a new matrix with results of {@code C = A - c}
     */
    public DenseMatrix minus(double val) {
        DenseMatrix res = new DenseMatrix(numRows, numColumns);
        for (int i = 0; i < numRows; i++)
            for (int j = 0; j < numColumns; j++)
                res.data[i][j] = data[i][j] - val;
        return res;
    }

    /**
     * Cholesky decomposition: computes the lower-triangular L with {@code A = L L^T} and
     * returns its transpose (upper triangular).
     *
     * @return {@code L^T}, or {@code null} if a diagonal entry becomes NaN (the matrix is
     *         not positive definite)
     * @throws RuntimeException if the matrix is not square
     */
    public DenseMatrix cholesky() {
        if (this.numRows != this.numColumns)
            throw new RuntimeException("Matrix is not square");

        int n = numRows;
        DenseMatrix L = new DenseMatrix(n, n);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                double sum = 0.0;
                for (int k = 0; k < j; k++)
                    sum += L.get(i, k) * L.get(j, k);

                double val = i == j ? Math.sqrt(data[i][i] - sum) : (data[i][j] - sum) / L.get(j, j);
                L.set(i, j, val);
            }
            if (Double.isNaN(L.get(i, i)))
                return null;
        }

        return L.transpose();
    }

    /**
     * @return a transposed matrix of current matrix
     */
    public DenseMatrix transpose() {
        DenseMatrix mat = new DenseMatrix(numColumns, numRows);
        for (int i = 0; i < mat.numRows; i++)
            for (int j = 0; j < mat.numColumns; j++)
                mat.set(i, j, this.data[j][i]);
        return mat;
    }

    /**
     * @return the sample covariance matrix (divisor n - 1) of the columns of the current matrix
     */
    public DenseMatrix cov() {
        DenseMatrix mat = new DenseMatrix(numColumns, numColumns);
        for (int i = 0; i < numColumns; i++) {
            DenseVector xi = this.column(i);
            xi = xi.minus(xi.mean());
            mat.set(i, i, xi.inner(xi) / (xi.size - 1));
            for (int j = i + 1; j < numColumns; j++) {
                DenseVector yi = this.column(j);
                double val = xi.inner(yi.minus(yi.mean())) / (xi.size - 1);
                mat.set(i, j, val);
                mat.set(j, i, val);
            }
        }
        return mat;
    }

    /**
     * Compute the inverse of a matrix by Gauss-Jordan elimination with full pivoting.
     *
     * @return the inverse matrix of current matrix
     * @throws RuntimeException if the matrix is not square or a pivot falls below 1e-10
     * @deprecated use {@code inv} instead which is slightly faster
     */
    @Deprecated
    public DenseMatrix inverse() {
        if (numRows != numColumns)
            throw new RuntimeException("Only square matrix can do inversion");

        int n = numRows;
        DenseMatrix mat = new DenseMatrix(this);

        if (n == 1) {
            mat.set(0, 0, 1.0 / mat.get(0, 0));
            return mat;
        }

        int row[] = new int[n];
        int col[] = new int[n];
        double temp[] = new double[n];
        int hold, I_pivot, J_pivot;
        double pivot, abs_pivot;

        // set up row and column interchange vectors
        for (int k = 0; k < n; k++) {
            row[k] = k;
            col[k] = k;
        }
        // begin main reduction loop
        for (int k = 0; k < n; k++) {
            // find largest element for pivot
            pivot = mat.get(row[k], col[k]);
            I_pivot = k;
            J_pivot = k;
            for (int i = k; i < n; i++) {
                for (int j = k; j < n; j++) {
                    abs_pivot = Math.abs(pivot);
                    if (Math.abs(mat.get(row[i], col[j])) > abs_pivot) {
                        I_pivot = i;
                        J_pivot = j;
                        pivot = mat.get(row[i], col[j]);
                    }
                }
            }
            if (Math.abs(pivot) < 1.0E-10)
                throw new RuntimeException("Matrix is singular !");

            hold = row[k];
            row[k] = row[I_pivot];
            row[I_pivot] = hold;
            hold = col[k];
            col[k] = col[J_pivot];
            col[J_pivot] = hold;

            // reduce about pivot
            mat.set(row[k], col[k], 1.0 / pivot);
            for (int j = 0; j < n; j++) {
                if (j != k) {
                    mat.set(row[k], col[j], mat.get(row[k], col[j]) * mat.get(row[k], col[k]));
                }
            }
            // inner reduction loop
            for (int i = 0; i < n; i++) {
                if (k != i) {
                    for (int j = 0; j < n; j++) {
                        if (k != j) {
                            double val = mat.get(row[i], col[j]) - mat.get(row[i], col[k]) * mat.get(row[k], col[j]);
                            mat.set(row[i], col[j], val);
                        }
                    }
                    mat.set(row[i], col[k], -mat.get(row[i], col[k]) * mat.get(row[k], col[k]));
                }
            }
        }
        // end main reduction loop

        // unscramble rows
        for (int j = 0; j < n; j++) {
            for (int i = 0; i < n; i++)
                temp[col[i]] = mat.get(row[i], j);

            for (int i = 0; i < n; i++)
                mat.set(i, j, temp[i]);
        }

        // unscramble columns
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++)
                temp[row[j]] = mat.get(i, col[j]);

            for (int j = 0; j < n; j++)
                mat.set(i, j, temp[j]);
        }

        return mat;
    }

    /**
     * NOTE: this implementation (adopted from PREA package) is slightly faster
     * than {@code inverse}, especially when {@code numRows} is large.
     * It performs Gauss-Jordan elimination with partial (row) pivoting.
     *
     * NOTE(review): if no non-zero pivot is found (singular matrix) it returns the partially
     * reduced result without signalling an error.
     *
     * @return the inverse matrix of current matrix
     * @throws RuntimeException if the matrix is not square
     */
    public DenseMatrix inv() {
        if (this.numRows != this.numColumns)
            throw new RuntimeException("Dimensions disagree");

        int n = this.numRows;

        DenseMatrix mat = DenseMatrix.eye(n);

        if (n == 1) {
            mat.set(0, 0, 1 / this.get(0, 0));
            return mat;
        }

        DenseMatrix b = new DenseMatrix(this);

        for (int i = 0; i < n; i++) {
            // find pivot:
            double mag = 0;
            int pivot = -1;

            for (int j = i; j < n; j++) {
                double mag2 = Math.abs(b.get(j, i));
                if (mag2 > mag) {
                    mag = mag2;
                    pivot = j;
                }
            }

            // no pivot (error):
            if (pivot == -1 || mag == 0)
                return mat;

            // move pivot row into position:
            if (pivot != i) {
                double temp;
                for (int j = i; j < n; j++) {
                    temp = b.get(i, j);
                    b.set(i, j, b.get(pivot, j));
                    b.set(pivot, j, temp);
                }

                for (int j = 0; j < n; j++) {
                    temp = mat.get(i, j);
                    mat.set(i, j, mat.get(pivot, j));
                    mat.set(pivot, j, temp);
                }
            }

            // normalize pivot row:
            mag = b.get(i, i);
            for (int j = i; j < n; j++)
                b.set(i, j, b.get(i, j) / mag);

            for (int j = 0; j < n; j++)
                mat.set(i, j, mat.get(i, j) / mag);

            // eliminate pivot row component from other rows:
            for (int k = 0; k < n; k++) {
                if (k == i)
                    continue;

                double mag2 = b.get(k, i);

                for (int j = i; j < n; j++)
                    b.set(k, j, b.get(k, j) - mag2 * b.get(i, j));

                for (int j = 0; j < n; j++)
                    mat.set(k, j, mat.get(k, j) - mag2 * mat.get(i, j));
            }
        }

        return mat;
    }

    /**
     * set one value to a specific row
     *
     * @param row
     *            row id
     * @param val
     *            value to be set
     */
    public void setRow(int row, double val) {
        Arrays.fill(data[row], val);
    }

    /**
     * set values of one dense vector to a specific row
     *
     * @param row
     *            row id
     * @param vals
     *            values of a dense vector; its size must equal {@code numColumns}
     */
    public void setRow(int row, DenseVector vals) {
        assert vals.size == numColumns;
        System.arraycopy(vals.data, 0, data[row], 0, numColumns);
    }

    /**
     * clear and reset all entries to 0
     */
    public void clear() {
        for (int i = 0; i < numRows; i++)
            setRow(i, 0.0);
    }

    @Override
    public String toString() {
        return Strings.toString(data);
    }
}
================================================
FILE: src/data_structure/DenseVector.java
================================================
// Copyright (C) 2014 Guibing Guo
//
// This file is part of LibRec.
//
// LibRec is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// LibRec is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with LibRec. If not, see .
//
package data_structure;
import happy.coding.io.Strings;
import happy.coding.math.Randoms;
import happy.coding.math.Stats;
import java.io.Serializable;
/**
* Data Structure: dense vector
*
* @author guoguibing
*
*/
public class DenseVector implements Serializable {
private static final long serialVersionUID = -2930574547913792430L;
protected int size;
protected double[] data;
/**
 * Creates a zero-filled vector of the given length.
 *
 * @param size
 *            number of entries in the vector
 */
public DenseVector(int size) {
    data = new double[size];
    this.size = size;
}
/**
 * Creates a vector that owns a private copy of {@code array}.
 */
public DenseVector(double[] array) {
    this(array, /* deep = */ true);
}
/**
 * Creates a vector over the entries of {@code array}.
 *
 * @param array
 *            source entries; the vector length is {@code array.length}
 * @param deep
 *            {@code true} to store a private copy; {@code false} to use
 *            {@code array} itself as storage, so writes through either side are
 *            visible to the other
 */
public DenseVector(double[] array, boolean deep) {
    size = array.length;
    data = deep ? array.clone() : array;
}
/**
 * Creates an independent (deep) copy of {@code vec}.
 */
public DenseVector(DenseVector vec) {
    this(vec.data, true);
}
/**
 * Returns a deep copy of this vector that shares no storage with it.
 */
@Override
public DenseVector clone() {
    return new DenseVector(data, true);
}
/** @return the number of entries in this vector */
public int size() {
    return size;
}
/**
 * Overwrites every entry, in index order, with a draw from {@code Randoms.gaussian}.
 *
 * @param mean  mean of the Gaussian
 * @param sigma spread of the Gaussian (presumably the standard deviation)
 */
public void init(double mean, double sigma) {
    for (int k = 0; k < size; ++k) {
        data[k] = Randoms.gaussian(mean, sigma);
    }
}
/**
 * Overwrites every entry, in index order, with a uniform random value in (0, 1).
 */
public void init() {
    for (int k = 0; k < size; ++k) {
        data[k] = Randoms.uniform();
    }
}
/**
 * Overwrites every entry, in index order, with a uniform random value in (0, range).
 */
public void init(double range) {
    for (int k = 0; k < size; ++k) {
        data[k] = Randoms.uniform(0, range);
    }
}
/**
 * @param idx position of the entry
 * @return the value stored at {@code idx}
 */
public double get(int idx) {
    final double value = data[idx];
    return value;
}
/**
 * @return the backing array itself (not a copy); writes to it modify this vector
 */
public double[] getData() {
    final double[] backing = data;
    return backing;
}
/**
 * @return the arithmetic mean of the entries, as computed by {@code Stats.mean}
 */
public double mean() {
    final double average = Stats.mean(data);
    return average;
}
/**
 * @return the sum of all entries, as computed by {@code Stats.sum}
 */
public double sum() {
    final double total = Stats.sum(data);
    return total;
}
/**
* @return squared summation of entries
*/
public double squaredSum(){
double sum = 0;
for (int i = 0; i < data.length; i ++) {
sum += data[i] * data[i];
}
return sum;
}
/**
* Set a value to entry [index]
*/
public void set(int idx, double val) {
data[idx] = val;
}
/**
* Set a value to all entries
*/
public void setAll(double val) {
for (int i = 0; i < size; i++)
data[i] = val;
}
/**
* Add a value to entry [index]
*/
public void add(int idx, double val) {
data[idx] += val;
}
/**
* Substract a value from entry [index]
*/
public void minus(int idx, double val) {
data[idx] -= val;
}
/**
* @return a dense vector by adding a value to all entries of current vector
*/
public DenseVector add(double val) {
DenseVector result = new DenseVector(size);
for (int i = 0; i < size; i++)
result.data[i] = this.data[i] + val;
return result;
}
/**
* @return a dense vector by substructing a value from all entries of current vector
*/
public DenseVector minus(double val) {
DenseVector result = new DenseVector(size);
for (int i = 0; i < size; i++)
result.data[i] = this.data[i] - val;
return result;
}
/**
* @return a dense vector by scaling a value to all entries of current vector
*/
public DenseVector scale(double val) {
DenseVector result = new DenseVector(size);
for (int i = 0; i < size; i++)
result.data[i] = this.data[i] * val;
return result;
}
public void selfScale(double val) {
for (int i = 0; i < size; i ++)
this.data[i] = this.data[i] * val;
}
/**
* Do vector operation: {@code a + b}
*
* @return a dense vector with results of {@code c = a + b}
*/
public DenseVector add(DenseVector vec) {
assert size == vec.size;
DenseVector result = new DenseVector(size);
for (int i = 0; i < result.size; i++)
result.data[i] = this.data[i] + vec.data[i];
return result;
}
/**
* Vector add operation to itself.
*/
public void selfAdd(DenseVector vec) {
assert size == vec.size;
for (int i = 0; i < size; i ++)
this.data[i] = this.data[i] + vec.data[i];
}
/**
* Do vector operation: {@code a - b}
*
* @return a dense vector with results of {@code c = a - b}
*/
public DenseVector minus(DenseVector vec) {
assert size == vec.size;
DenseVector result = new DenseVector(size);
for (int i = 0; i < vec.size; i++)
result.data[i] = this.data[i] - vec.data[i];
return result;
}
/**
* Do vector operation: {@code a^t * b}
*
* @return the inner product of two vectors
*/
public double inner(DenseVector vec) {
assert size == vec.size;
double result = 0;
for (int i = 0; i < vec.size; i++)
result += get(i) * vec.get(i);
return result;
}
/**
* Do vector operation: {@code a * b^t}
*
* @return the outer product of two vectors
*/
public DenseMatrix outer(DenseVector vec) {
DenseMatrix mat = new DenseMatrix(this.size, vec.size);
for (int i = 0; i < mat.numRows; i++)
for (int j = 0; j < mat.numColumns; j++)
mat.set(i, j, get(i) * vec.get(j));
return mat;
}
@Override
public String toString() {
return Strings.toString(data);
}
}
================================================
FILE: src/data_structure/Pair.java
================================================
package data_structure;
import java.util.Objects;
/**
 * Generic 2-tuple with final fields and value-based {@link #equals(Object)} /
 * {@link #hashCode()}. Elements may be {@code null}.
 *
 * @param <F> type of the first element
 * @param <S> type of the second element
 */
public class Pair<F, S> {
	/** First element; may be {@code null}. */
	public final F first;
	/** Second element; may be {@code null}. */
	public final S second;

	public Pair(F first, S second) {
		this.first = first;
		this.second = second;
	}

	/**
	 * Two pairs are equal when both elements are equal according to
	 * {@link Objects#equals(Object, Object)} (null-safe).
	 */
	@Override
	public boolean equals(Object o) {
		if (this == o) {
			return true;
		}
		if (!(o instanceof Pair)) {
			return false;
		}
		Pair<?, ?> p = (Pair<?, ?>) o;
		return Objects.equals(p.first, first) && Objects.equals(p.second, second);
	}

	/**
	 * Order-sensitive hash. A plain XOR of the element hashes would make (a, b)
	 * and (b, a) collide and map every (i, i) integer pair to 0, which degrades
	 * hash tables keyed by matrix index pairs.
	 */
	@Override
	public int hashCode() {
		return Objects.hash(first, second);
	}

	@Override
	public String toString() {
		return "(" + first + ", " + second + ")";
	}

	/**
	 * Convenience factory that lets the compiler infer the element types.
	 */
	public static <A, B> Pair<A, B> create(A a, B b) {
		return new Pair<A, B>(a, b);
	}
}
================================================
FILE: src/data_structure/Rating.java
================================================
package data_structure;
/**
 * A single user-item interaction record.
 */
public class Rating {
	public int userId; // user id, starts from 0
	public int itemId; // item id, starts from 0
	public float score;
	/** Value of the optional fourth column; 0 when the parsed line has no such field. */
	public long timestamp;

	public Rating(int userId, int itemId, float score, long timestamp) {
		this.userId = userId;
		this.itemId = itemId;
		this.score = score;
		this.timestamp = timestamp;
	}

	/**
	 * Parse a tab-separated line {@code userId<TAB>itemId<TAB>score[<TAB>timestamp]}.
	 * Each field is trimmed, so trailing whitespace such as the {@code '\r'} left by
	 * CRLF line endings does not break integer/long parsing ({@code Integer.parseInt}
	 * and {@code Long.parseLong} do not tolerate surrounding whitespace).
	 *
	 * @param line one record of the rating file
	 * @throws IllegalArgumentException if fewer than three fields are present, or
	 *         (as {@link NumberFormatException}) if a field is not numeric
	 */
	public Rating(String line) {
		String[] arr = line.split("\t");
		if (arr.length < 3) {
			throw new IllegalArgumentException(
					"Expected at least 3 tab-separated fields but got " + arr.length + ": " + line);
		}
		userId = Integer.parseInt(arr[0].trim());
		itemId = Integer.parseInt(arr[1].trim());
		score = Float.parseFloat(arr[2].trim());
		if (arr.length > 3) timestamp = Long.parseLong(arr[3].trim());
	}

	@Override
	public String toString() {
		return "<" + userId + "," + itemId + "," + score + "," + timestamp + ">";
	}
}
================================================
FILE: src/data_structure/SparseMatrix.java
================================================
package data_structure;
import java.io.Serializable;
import java.util.ArrayList;
import data_structure.Pair;
/**
* This class implements sparse matrix, containing empty values for most space.
*
* @author Joonseok Lee
* @since 2012. 4. 20
* @version 1.1
*/
public class SparseMatrix implements Serializable{
private static final long serialVersionUID = 8003;
/** The number of rows. */
private int M;
/** The number of columns. */
private int N;
/** The array of row references. */
private SparseVector[] rows;
/** The array of column references. */
private SparseVector[] cols;
/*========================================
* Constructors
*========================================*/
/**
* Construct an empty sparse matrix, with a given size.
*
* @param m The number of rows.
* @param n The number of columns.
*/
public SparseMatrix(int m, int n) {
this.M = m;
this.N = n;
rows = new SparseVector[M];
cols = new SparseVector[N];
for (int i = 0; i < M; i++) {
rows[i] = new SparseVector(N);
}
for (int j = 0; j < N; j++) {
cols[j] = new SparseVector(M);
}
}
/**
* Construct an empty sparse matrix, with data copied from another sparse matrix.
*
* @param sm The matrix having data being copied.
*/
public SparseMatrix(SparseMatrix sm) {
this.M = sm.M;
this.N = sm.N;
rows = new SparseVector[M];
cols = new SparseVector[N];
for (int i = 0; i < M; i++) {
rows[i] = sm.getRow(i);
}
for (int j = 0; j < N; j++) {
cols[j] = sm.getCol(j);
}
}
/*========================================
* Getter/Setter
*========================================*/
/**
* Retrieve a stored value from the given index.
*
* @param i The row index to retrieve.
* @param j The column index to retrieve.
* @return The value stored at the given index.
*/
public double getValue(int i, int j) {
return rows[i].getValue(j);
}
/**
* Set a new value at the given index.
*
* @param i The row index to store new value.
* @param j The column index to store new value.
* @param value The value to store.
*/
public void setValue(int i, int j, double value) {
if (value == 0.0) {
rows[i].remove(j);
cols[j].remove(i);
}
else {
rows[i].setValue(j, value);
cols[j].setValue(i, value);
}
}
/**
* Set a new row vector at the given row index.
* @param i The row index to store new vector
* @param newVector
*/
public void setRowVector(int i, SparseVector newVector) {
if (newVector.length() != this.N)
throw new RuntimeException("Vector lengths disagree");
if (i < 0 || i >= this.M)
throw new RuntimeException("Wrong input row index.");
// Clear the values of the current rowVector.
if (rows[i].indexList() != null) {
for (int j : rows[i].indexList()) {
this.setValue(i, j, 0);
}
}
// Set the new vector.
if (newVector.indexList() != null) {
for (int j : newVector.indexList()) {
this.setValue(i, j, newVector.getValue(j));
}
}
}
/**
* Set a new row vector with non-negative constraint at the given row index.
* If the value is negative, set it as 0.
*
* @param i The row index to store new vector
* @param newVector
*/
public void setRowVectorNonnegative(int i, SparseVector newVector) {
if (newVector.length() != this.N)
throw new RuntimeException("Vector lengths disagree");
if (i < 0 || i >= this.M)
throw new RuntimeException("Wrong input row index.");
// Clear the values of the current rowVector.
if (rows[i].indexList() != null) {
for (int j : rows[i].indexList()) {
this.setValue(i, j, 0);
}
}
// Set the new vector with nonnegative constraint.
if (newVector.indexList() != null) {
for (int j : newVector.indexList()) {
double value = newVector.getValue(j);
this.setValue(i, j, value > 0 ? value : 0);
}
}
}
/**
* Set a new col vector at the given col index.
*/
public void setColVector(int j, SparseVector newVector) {
if (newVector.length() != this.M)
throw new RuntimeException("Vector lengths disagree");
if (j < 0 || j >= this.N)
throw new RuntimeException("Wrong input column index.");
// Clear the values of the current colVector
if (cols[j].indexList() != null) {
for (int i : cols[j].indexList()) {
this.setValue(i, j, 0);
}
}
// Set the new vector.
if (newVector.indexList() != null) {
for (int i : newVector.indexList()) {
this.setValue(i, j, newVector.getValue(i));
}
}
}
/**
* Set a new size of the matrix.
*
* @param m The new row count.
* @param n The new column count.
*/
public void setSize(int m, int n) {
this.M = m;
this.N = n;
}
/**
* Return a reference of a given row.
* Make sure to use this method only for read-only purpose.
*
* @param index The row index to retrieve.
* @return A reference to the designated row.
*/
public SparseVector getRowRef(int index) {
return rows[index];
}
/**
* Return a copy of a given row.
* Use this if you do not want to affect to original data.
*
* @param index The row index to retrieve.
* @return A reference to the designated row.
*/
public SparseVector getRow(int index) {
SparseVector newVector = this.rows[index].copy();
return newVector;
}
/**
* Return a reference of a given column.
* Make sure to use this method only for read-only purpose.
*
* @param index The column index to retrieve.
* @return A reference to the designated column.
*/
public SparseVector getColRef(int index) {
return cols[index];
}
/**
* Return a copy of a given column.
* Use this if you do not want to affect to original data.
*
* @param index The column index to retrieve.
* @return A reference to the designated column.
*/
public SparseVector getCol(int index) {
SparseVector newVector = this.cols[index].copy();
return newVector;
}
/**
* Calculate average value for each row.
*
* @param default_value The default average of a row if it has no values.
* @return A SparseVector that each value denotes the average of the row vector.
**/
public SparseVector getRowAverage(double defalut_value) {
SparseVector rowAverage = new SparseVector(this.M);
for (int u = 0; u < this.M; u++) {
SparseVector v = this.getRowRef(u);
double avg = v.average();
if (Double.isNaN(avg)) { // no rate is available: set it as median value.
avg = defalut_value;
}
rowAverage.setValue(u, avg);
}
return rowAverage;
}
/**
* Calculate average value for each column.
*
* @param default_value The default average of a column if it has no values.
* @return A SparseVector that each value denotes the average of the column vector.
*/
public SparseVector getColumnAverage(double defalut_value) {
SparseVector columnAverage = new SparseVector(this.N);
for (int i = 0; i < this.N; i++) {
SparseVector j = this.getColRef(i);
double avg = j.average();
if (Double.isNaN(avg)) { // no rate is available: set it as median value.
avg = defalut_value;
}
columnAverage.setValue(i, avg);
}
return columnAverage;
}
/*========================================
* Properties
*========================================*/
/**
* Capacity of this matrix.
*
* @return An array containing the length of this matrix.
* Index 0 contains row count, while index 1 column count.
*/
public int[] length() {
int[] lengthArray = new int[2];
lengthArray[0] = this.M;
lengthArray[1] = this.N;
return lengthArray;
}
/**
* Size of this matrix, M * N
*/
public int size() {
return M * N;
}
/**
* Actual number of items in the matrix.
*
* @return The number of items in the matrix.
*/
public int itemCount() {
int sum = 0;
if (M > N) {
for (int i = 0; i < M; i++) {
sum += rows[i].itemCount();
}
}
else {
for (int j = 0; j < N; j++) {
sum += cols[j].itemCount();
}
}
return sum;
}
/**
* Number of non-zero elements in the matrix.
*
* @return The number of non-zero elements in the matrix.
*/
public int nonZeroCount() {
int sum = 0;
if (M > N) {
for (int i = 0; i < M; i++) {
sum += rows[i].nonZeroCount();
}
}
else {
for (int j = 0; j < N; j++) {
sum += cols[j].nonZeroCount();
}
}
return sum;
}
/**
* Return items in the diagonal in vector form.
*
* @return Diagonal vector from the matrix.
*/
public SparseVector diagonal() {
SparseVector v = new SparseVector(Math.min(this.M, this.N));
for (int i = 0; i < Math.min(this.M, this.N); i++) {
double value = this.getValue(i, i);
if (value > 0.0) {
v.setValue(i, value);
}
}
return v;
}
/**
* The value of maximum element in the matrix.
*
* @return The maximum value.
*/
public double max() {
double curr = Double.MIN_VALUE;
for (int i = 0; i < this.M; i++) {
SparseVector v = this.getRowRef(i);
if (v.itemCount() > 0) {
double rowMax = v.max();
if (v.max() > curr) {
curr = rowMax;
}
}
}
return curr;
}
/**
* The value of minimum element in the matrix.
*
* @return The minimum value.
*/
public double min() {
double curr = Double.MAX_VALUE;
for (int i = 0; i < this.M; i++) {
SparseVector v = this.getRowRef(i);
if (v.itemCount() > 0) {
double rowMin = v.min();
if (v.min() < curr) {
curr = rowMin;
}
}
}
return curr;
}
/**
* Sum of every element. It ignores non-existing values.
*
* @return The sum of all elements.
*/
public double sum() {
double sum = 0.0;
for (int i = 0; i < this.M; i++) {
SparseVector v = this.getRowRef(i);
sum += v.sum();
}
return sum;
}
/**
* Square sum of all elements. It ignores non-existing values.
*
* @return The square sum of all elements
*/
public double squareSum() {
double sum = 0.0;
for (int i = 0; i < this.M; i++) {
SparseVector v = this.getRowRef(i);
sum += v.squareSum();
}
return sum;
}
/**
* Average of every element. It ignores non-existing values.
*
* @return The average value.
*/
public double average() {
return this.sum() / this.itemCount();
}
/**
* Variance of every element. It ignores non-existing values.
*
* @return The variance value.
*/
public double variance() {
double avg = this.average();
double sum = 0.0;
for (int i = 0; i < this.M; i++) {
ArrayList itemList = this.getRowRef(i).indexList();
for (int j : itemList) {
sum += Math.pow(this.getValue(i, j) - avg, 2);
}
}
return sum / this.itemCount();
}
/**
* Standard Deviation of every element. It ignores non-existing values.
*
* @return The standard deviation value.
*/
public double stdev() {
return Math.sqrt(this.variance());
}
/**
* Return the (non-zero) index pairs.
* @return
*/
public ArrayList> indexPairs() {
ArrayList> pairs = new ArrayList>();
for (int i = 0; i < M; i ++) {
for (int j : rows[i].indexList()) {
pairs.add(new Pair(i, j));
}
}
return pairs;
}
/*========================================
* Matrix operations
*========================================*/
/**
* Scalar subtraction (aX).
*
* @param alpha The scalar value to be multiplied to this matrix.
* @return The resulting matrix after scaling.
*/
public SparseMatrix scale(double alpha) {
SparseMatrix A = new SparseMatrix(this.M, this.N);
for (int i = 0; i < A.M; i++) {
A.rows[i] = this.getRowRef(i).scale(alpha);
}
for (int j = 0; j < A.N; j++) {
A.cols[j] = this.getColRef(j).scale(alpha);
}
return A;
}
/**
* Scalar subtraction (aX) on the matrix itself.
* This is used for minimizing memory usage.
*
* @param alpha The scalar value to be multiplied to this matrix.
*/
public SparseMatrix selfScale(double alpha) {
for (int i = 0; i < this.M; i++) {
ArrayList itemList = this.getRowRef(i).indexList();
for (int j : itemList) {
this.setValue(i, j, this.getValue(i, j) * alpha);
}
}
return this;
}
/**
* Scalar addition.
* @param alpha The scalar value to be added to this matrix.
* @return The resulting matrix after addition.
*/
public SparseMatrix add(double alpha) {
SparseMatrix A = new SparseMatrix(this.M, this.N);
for (int i = 0; i < A.M; i++) {
A.rows[i] = this.getRowRef(i).add(alpha);
}
for (int j = 0; j < A.N; j++) {
A.cols[j] = this.getColRef(j).add(alpha);
}
return A;
}
/**
* Scalar addition on the matrix itself.
* @param alpha The scalar value to be added to this matrix.
*/
public void selfAdd(double alpha) {
for (int i = 0; i < this.M; i++) {
ArrayList itemList = this.getRowRef(i).indexList();
for (int j : itemList) {
this.setValue(i, j, this.getValue(i, j) + alpha);
}
}
}
/**
* Exponential of a given constant.
*
* @param alpha The exponent.
* @return The resulting exponential matrix.
*/
public SparseMatrix exp(double alpha) {
for (int i = 0; i < this.M; i++) {
SparseVector b = this.getRowRef(i);
ArrayList indexList = b.indexList();
for (int j : indexList) {
this.setValue(i, j, Math.pow(alpha, this.getValue(i, j)));
}
}
return this;
}
/**
* The transpose of the matrix.
* This is simply implemented by interchanging row and column each other.
*
* @return The transpose of the matrix.
*/
public SparseMatrix transpose() {
SparseMatrix A = new SparseMatrix(this.N, this.M);
A.cols = this.rows;
A.rows = this.cols;
return A;
}
/**
* Matrix-vector product (b = Ax)
*
* @param x The vector to be multiplied to this matrix.
* @throws RuntimeException when dimensions disagree
* @return The resulting vector after multiplication.
*/
public SparseVector times(SparseVector x) {
if (N != x.length())
throw new RuntimeException("Dimensions disagree");
SparseMatrix A = this;
SparseVector b = new SparseVector(M);
for (int i = 0; i < M; i++) {
b.setValue(i, A.rows[i].innerProduct(x));
}
return b;
}
/**
* Matrix-matrix product (C = AB)
*
* @param B The matrix to be multiplied to this matrix.
* @throws RuntimeException when dimensions disagree
* @return The resulting matrix after multiplication.
*/
public SparseMatrix times(SparseMatrix B) {
// original implementation
if (N != (B.length())[0])
throw new RuntimeException("Dimensions disagree");
SparseMatrix A = this;
SparseMatrix C = new SparseMatrix(M, (B.length())[1]);
for (int i = 0; i < M; i++) {
for (int j = 0; j < (B.length())[1]; j++) {
SparseVector x = A.getRowRef(i);
SparseVector y = B.getColRef(j);
if (x != null && y != null)
C.setValue(i, j, x.innerProduct(y));
else
C.setValue(i, j, 0.0);
}
}
return C;
}
/**
* Element-wise matrix product (C_ij = A_ij * B_ij)
* @param B
* @return
*/
public SparseMatrix dotTimes(SparseMatrix B) {
if (M != B.M || N != B.N) {
throw new RuntimeException("dotTimes: Matrices are not of the same size!");
}
SparseMatrix C = new SparseMatrix(M, N);
for (int i = 0; i < M; i ++) {
ArrayList wordList = this.rows[i].indexList();
for (int j : wordList) {
double A_ij = getValue(i, j);
double B_ij = B.getValue(i, j);
if (A_ij != 0 && B_ij != 0) {
C.setValue(i, j, A_ij * B_ij);
}
}
}
return C;
}
/**
* Element-wise matrix division (C_ij = A_ij / B_ij)
* It ignore 0 elements.
* @param B
* @return
*/
public SparseMatrix dotDivide(SparseMatrix B) {
if (M != B.M || N != B.N) {
throw new RuntimeException("dotDivide: Matrices are not of the same size!");
}
SparseMatrix C = new SparseMatrix(M, N);
for (int i = 0; i < M; i ++) {
ArrayList wordList = this.rows[i].indexList();
for (int j : wordList) {
double A_ij = getValue(i, j);
double B_ij = B.getValue(i, j);
if (A_ij != 0 && B_ij != 0) {
C.setValue(i, j, A_ij / B_ij);
}
}
}
return C;
}
/** TF-IDF term weighting on an itemWords Matrix (row denotes item, column denotes word, each value is an integer).
*
* @return TF-IDF term weighted matrix.
*/
public SparseMatrix tfidf() {
SparseMatrix C = new SparseMatrix(M, N);
for (int i = 0; i < M; i++) { // row represents a doc
ArrayList wordList = rows[i].indexList();
for (int j : wordList) { // col represent a word
if (this.getValue(i, j) != 0) {
double TF = 1 + log2(getValue(i, j));
double IDF = log2((double)M / cols[j].itemCount());
C.setValue(i, j, TF * IDF);
}
}
}
return C;
}
/**
* IDF term weighting on an itemWords matrix.
* @return
*/
public SparseMatrix idf() {
SparseMatrix C = new SparseMatrix(M, N);
for (int i = 0; i < M; i++) { // row represents a doc
ArrayList wordList = rows[i].indexList();
for (int j : wordList) { // col represent a word
if (this.getValue(i, j) != 0) {
double TF = 1;
double IDF = log2((double)M / cols[j].itemCount());
C.setValue(i, j, TF * IDF);
}
}
}
return C;
}
/**
* TF term weighting on an itemWords Matrix (row denotes item, column denotes word, each value is an integer).
* @return TF term weighted matrix.
*/
public SparseMatrix tf() {
SparseMatrix C = new SparseMatrix(M, N);
for (int i = 0; i < M; i++) { // row represents a doc
ArrayList wordList = rows[i].indexList();
for (int j : wordList) { // col represent a word
if (this.getValue(i, j) != 0) {
double TF = 1 + log2(getValue(i, j));
C.setValue(i, j, TF);
}
}
}
return C;
}
public SparseMatrix log2() {
SparseMatrix C = new SparseMatrix(M, N);
for (int i = 0; i < M; i++) {
ArrayList indexList = this.getRowRef(i).indexList();
for (int j : indexList) {
C.setValue(i, j, 1 + log2(this.getValue(i, j)));
}
}
return C;
}
private double log2(double n) {
return Math.log(n) / Math.log(2);
}
/** Convert a non-negative matrix to a row stochastic matrix (i.e. sum of a row is 1).
* It ignores 0 row vector.
*
* @return Row stochastic matrix
*/
public SparseMatrix rowStochastic() {
SparseMatrix C = new SparseMatrix(M, N);
for (int i = 0; i < M; i++) {
double sum = rows[i].sum();
if (sum != 0) {
for (int j : this.rows[i].indexList()) {
C.setValue(i, j, getValue(i, j) / sum);
}
}
}
return C;
}
/**
* Apply L2 norm on each row vector.
* It ignores 0 row vector.
* @return
*/
public SparseMatrix rowL2Norm() {
SparseMatrix C = new SparseMatrix(M, N);
for (int i = 0; i < M; i++) {
double squareSum = rows[i].squareSum();
if (squareSum != 0) {
double l2_norm = Math.sqrt(squareSum);
for (int j : rows[i].indexList()) {
C.setValue(i, j, getValue(i, j) / l2_norm);
}
}
}
return C;
}
/** Convert a non-negative matrix to a column stochastic matrix (i.e. sum of a column is 1).
* It ignores 0 column vector.
*
* @return Column stochastic matrix.
*/
public SparseMatrix colStochastic() {
SparseMatrix C = new SparseMatrix(this.M, this.N);
for (int j = 0; j < this.N; j++) {
double sum = this.cols[j].sum();
if (sum != 0) {
for (int i : this.cols[j].indexList()) {
C.setValue(i, j, this.getValue(i, j) / sum);
}
}
}
return C;
}
/**
* Matrix-matrix product (A = AB), without using extra memory.
*
* @param B The matrix to be multiplied to this matrix.
* @throws RuntimeException when dimensions disagree
*/
public void selfTimes(SparseMatrix B) {
// original implementation
if (N != (B.length())[0])
throw new RuntimeException("Dimensions disagree");
for (int i = 0; i < M; i++) {
SparseVector tmp = new SparseVector(N);
for (int j = 0; j < (B.length())[1]; j++) {
SparseVector x = this.getRowRef(i);
SparseVector y = B.getColRef(j);
if (x != null && y != null)
tmp.setValue(j, x.innerProduct(y));
else
tmp.setValue(j, 0.0);
}
for (int j = 0; j < (B.length())[1]; j++) {
this.setValue(i, j, tmp.getValue(j));
}
}
}
/**
* Matrix-matrix sum (C = A + B)
*
* @param B The matrix to be added to this matrix.
* @throws RuntimeException when dimensions disagree
* @return The resulting matrix after summation.
*/
public SparseMatrix plus(SparseMatrix B) {
SparseMatrix A = this;
if (A.M != B.M || A.N != B.N)
throw new RuntimeException("Dimensions disagree");
SparseMatrix C = new SparseMatrix(M, N);
for (int i = 0; i < M; i++) {
C.rows[i] = A.rows[i].plus(B.rows[i]);
}
for (int j = 0; j < N; j++) {
C.cols[j] = A.cols[j].plus(B.cols[j]);
}
return C;
}
/**
* Matrix-matrix minus (C = A - B)
*
* @param B The matrix to be deducted to this matrix.
* @throws RuntimeException when dimensions disagree
* @return The resulting matrix after minus.
*/
public SparseMatrix minus(SparseMatrix B) {
SparseMatrix A = this;
if (A.M != B.M || A.N != B.N)
throw new RuntimeException("Dimensions disagree");
SparseMatrix C = new SparseMatrix(M, N);
for (int i = 0; i < M; i++) {
C.rows[i] = A.rows[i].minus(B.rows[i]);
}
for (int j = 0; j < N; j++) {
C.cols[j] = A.cols[j].minus(B.cols[j]);
}
return C;
}
/**
* Generate an identity matrix with the given size.
*
* @param n The size of requested identity matrix.
* @return An identity matrix with the size of n by n.
*/
public static SparseMatrix makeIdentity(int n) {
SparseMatrix m = new SparseMatrix(n, n);
for (int i = 0; i < n; i++) {
m.setValue(i, i, 1.0);
}
return m;
}
/**
* Generate a uniform matrix with the given size.
* The sum of each row is 1.
* @param m
* @param n
* @return
*/
public static SparseMatrix makeUniform(int M, int N) {
SparseMatrix m = new SparseMatrix(M, N);
for (int i = 0; i < M; i ++) {
for (int j = 0; j < N; j++) {
m.setValue(i, j, 1.0 / N);
}
}
return m;
}
/**
* Generate a random matrix with the given size and sparseRate.
* Each entry is in the range [0,1]
* @param M
* @param N
* @param sparseRate
* @return
*/
public static SparseMatrix makeRandom(int M, int N, double sparseRate) {
if (sparseRate <=0 || sparseRate >1) {
throw new RuntimeException("SparseRate input error!");
}
SparseMatrix m = new SparseMatrix(M, N);
for (int i = 0; i < M; i ++) {
for (int j = 0; j < N; j ++) {
double random = Math.random();
if (random < sparseRate) {
m.setValue(i, j, Math.random());
}
}
}
return m;
}
/**
* Calculate inverse matrix.
*
* @throws RuntimeException when dimensions disagree.
* @return The inverse of current matrix.
*/
public SparseMatrix inverse() {
if (this.M != this.N)
throw new RuntimeException("Dimensions disagree");
SparseMatrix original = this;
SparseMatrix newMatrix = makeIdentity(this.M);
int n = this.M;
if (n == 1) {
newMatrix.setValue(0, 0, 1 / original.getValue(0, 0));
return newMatrix;
}
SparseMatrix b = new SparseMatrix(original);
for (int i = 0; i < n; i++) {
// find pivot:
double mag = 0;
int pivot = -1;
for (int j = i; j < n; j++) {
double mag2 = Math.abs(b.getValue(j, i));
if (mag2 > mag) {
mag = mag2;
pivot = j;
}
}
// no pivot (error):
if (pivot == -1 || mag == 0) {
return newMatrix;
}
// move pivot row into position:
if (pivot != i) {
double temp;
for (int j = i; j < n; j++) {
temp = b.getValue(i, j);
b.setValue(i, j, b.getValue(pivot, j));
b.setValue(pivot, j, temp);
}
for (int j = 0; j < n; j++) {
temp = newMatrix.getValue(i, j);
newMatrix.setValue(i, j, newMatrix.getValue(pivot, j));
newMatrix.setValue(pivot, j, temp);
}
}
// normalize pivot row:
mag = b.getValue(i, i);
for (int j = i; j < n; j ++) {
b.setValue(i, j, b.getValue(i, j) / mag);
}
for (int j = 0; j < n; j ++) {
newMatrix.setValue(i, j, newMatrix.getValue(i, j) / mag);
}
// eliminate pivot row component from other rows:
for (int k = 0; k < n; k ++) {
if (k == i)
continue;
double mag2 = b.getValue(k, i);
for (int j = i; j < n; j ++) {
b.setValue(k, j, b.getValue(k, j) - mag2 * b.getValue(i, j));
}
for (int j = 0; j < n; j ++) {
newMatrix.setValue(k, j, newMatrix.getValue(k, j) - mag2 * newMatrix.getValue(i, j));
}
}
}
return newMatrix;
}
/**
* Calculate Cholesky decomposition of the matrix.
*
* @throws RuntimeException when matrix is not square.
* @return The Cholesky decomposition result.
*/
public SparseMatrix cholesky() {
if (this.M != this.N)
throw new RuntimeException("Matrix is not square");
SparseMatrix A = this;
int n = A.M;
SparseMatrix L = new SparseMatrix(n, n);
for (int i = 0; i < n; i++) {
for (int j = 0; j <= i; j++) {
double sum = 0.0;
for (int k = 0; k < j; k++) {
sum += L.getValue(i, k) * L.getValue(j, k);
}
if (i == j) {
L.setValue(i, i, Math.sqrt(A.getValue(i, i) - sum));
}
else {
L.setValue(i, j, 1.0 / L.getValue(j, j) * (A.getValue(i, j) - sum));
}
}
if (Double.isNaN(L.getValue(i, i))) {
//throw new RuntimeException("Matrix not positive definite: (" + i + ", " + i + ")");
return null;
}
}
return L.transpose();
}
/**
* Generate a covariance matrix of the current matrix.
*
* @return The covariance matrix of the current matrix.
*/
public SparseMatrix covariance() {
int columnSize = this.N;
SparseMatrix cov = new SparseMatrix(columnSize, columnSize);
for (int i = 0; i < columnSize; i++) {
for (int j = i; j < columnSize; j++) {
SparseVector data1 = this.getCol(i);
SparseVector data2 = this.getCol(j);
double avg1 = data1.average();
double avg2 = data2.average();
double value = data1.sub(avg1).innerProduct(data2.sub(avg2)) / (data1.length()-1);
cov.setValue(i, j, value);
cov.setValue(j, i, value);
}
}
return cov;
}
/*========================================
* Matrix operations (partial)
*========================================*/
/**
* Scalar Multiplication only with indices in indexList.
*
* @param alpha The scalar to be multiplied to this matrix.
* @param indexList The list of indices to be applied summation.
* @return The resulting matrix after scaling.
*/
public SparseMatrix partScale(double alpha, int[] indexList) {
if (indexList != null) {
for (int i : indexList) {
for (int j : indexList) {
this.setValue(i, j, this.getValue(i, j) * alpha);
}
}
}
return this;
}
/**
* Matrix summation (A = A + B) only with indices in indexList.
*
* @param B The matrix to be added to this matrix.
* @param indexList The list of indices to be applied summation.
* @throws RuntimeException when dimensions disagree.
* @return The resulting matrix after summation.
*/
public SparseMatrix partPlus(SparseMatrix B, int[] indexList) {
if (indexList != null) {
if (this.M != B.M || this.N != B.N)
throw new RuntimeException("Dimensions disagree");
for (int i : indexList) {
this.rows[i].partPlus(B.rows[i], indexList);
}
for (int j : indexList) {
this.cols[j].partPlus(B.cols[j], indexList);
}
}
return this;
}
/**
* Matrix subtraction (A = A - B) only with indices in indexList.
*
* @param B The matrix to be subtracted from this matrix.
* @param indexList The list of indices to be applied subtraction.
* @throws RuntimeException when dimensions disagree.
* @return The resulting matrix after subtraction.
*/
public SparseMatrix partMinus(SparseMatrix B, int[] indexList) {
if (indexList != null) {
if (this.M != B.M || this.N != B.N)
throw new RuntimeException("Dimensions disagree");
for (int i : indexList) {
this.rows[i].partMinus(B.rows[i], indexList);
}
for (int j : indexList) {
this.cols[j].partMinus(B.cols[j], indexList);
}
}
return this;
}
/**
* Matrix-vector product (b = Ax) only with indices in indexList.
*
* @param x The vector to be multiplied to this matrix.
* @param indexList The list of indices to be applied multiplication.
* @return The resulting vector after multiplication.
*/
public SparseVector partTimes(SparseVector x, int[] indexList) {
if (indexList == null)
return x;
SparseVector b = new SparseVector(M);
for (int i : indexList) {
b.setValue(i, this.rows[i].partInnerProduct(x, indexList));
}
return b;
}
/**
* Convert the matrix to a printable string.
*
* @return The resulted string in the form of "(1, 2: 5.0) (2, 4: 4.5)"
*/
@Override
public String toString() {
String s = "";
for (int i = 0; i < this.M; i++) {
SparseVector row = this.getRowRef(i);
if (row.itemCount() == 0) continue;
for (int j : row.indexList()) {
s += "(" + i + ", " + j + ": " + this.getValue(i, j) + ") ";
}
s += "\r\n";
}
return s;
}
}
================================================
FILE: src/data_structure/SparseVector.java
================================================
package data_structure;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import utils.CommonUtils;
/**
 * This class implements sparse vector, containing empty values for most space.
 * <p>
 * Only non-zero values are stored: {@link #setValue(int, double)} drops the entry
 * when it is set to 0.0, and {@link #getValue(int)} reports 0.0 for every index
 * without a stored entry. Indices are not validated against {@link #length()}.
 *
 * @author Joonseok Lee
 * @since 2012. 4. 20
 * @version 1.1
 */
public class SparseVector implements Serializable {
    private static final long serialVersionUID = 8002;
    /** The length (maximum number of items to be stored) of sparse vector. */
    private int N;
    /** Stored (index, value) pairs; only non-zero values are ever put here. */
    private final DataMap map;

    /*========================================
     * Constructors
     *========================================*/
    /**
     * Construct an empty sparse vector, with capacity 0.
     * Capacity can be reset with setLength method later.
     */
    public SparseVector() {
        this(0);
    }

    /**
     * Construct a new, empty sparse vector with length n.
     *
     * @param n The capacity of new sparse vector.
     */
    public SparseVector(int n) {
        this.N = n;
        this.map = new DataMap();
    }

    /**
     * Construct a sparse vector with the same length and stored entries as another vector.
     * <p>
     * Only the stored entries of {@code sv} are visited (O(nnz) rather than
     * O(length)), which also makes this constructor agree with {@link #copy()}.
     *
     * @param sv The vector having data being copied.
     */
    public SparseVector(SparseVector sv) {
        this(sv.N);
        for (int i : sv.map) {
            this.setValue(i, sv.getValue(i));
        }
    }

    /*========================================
     * Getter/Setter
     *========================================*/
    /**
     * Set a new value at the given index. Setting 0.0 removes the stored entry.
     *
     * @param i The index to store new value.
     * @param value The value to store.
     */
    public void setValue(int i, double value) {
        if (value == 0.0)
            map.remove(i);
        else
            map.put(i, value);
    }

    /**
     * Replace the contents of this vector with the entries of {@code newVector}.
     *
     * @param newVector The vector whose entries are copied; must have the same length.
     * @throws RuntimeException If the two vector lengths differ.
     */
    public void setVector(SparseVector newVector) {
        if (this.length() != newVector.length()) {
            throw new RuntimeException("Vector length disagrees.");
        }
        // Assigning a vector to itself is a no-op; without this check the clearing
        // pass below would wipe the source before it is copied.
        if (newVector == this) {
            return;
        }
        // indexList() returns a snapshot, so removing entries does not mutate the
        // collection being iterated.
        for (int i : this.indexList())
            this.setValue(i, 0);
        for (int i : newVector.map)
            this.setValue(i, newVector.getValue(i));
    }

    /**
     * Retrieve a stored value from the given index.
     *
     * @param i The index to retrieve.
     * @return The value stored at the given index, or 0.0 if none is stored.
     */
    public double getValue(int i) {
        if (map.contains(i))
            return map.get(i);
        else
            return 0.0;
    }

    /**
     * Delete a value stored at the given index.
     *
     * @param i The index to delete the value in it.
     */
    public void remove(int i) {
        if (map.contains(i))
            map.remove(i);
    }

    /**
     * Copy the whole sparse vector and make a clone.
     *
     * @return A clone of the current sparse vector, containing same values.
     */
    public SparseVector copy() {
        SparseVector newVector = new SparseVector(this.N);
        for (int i : this.map) {
            newVector.setValue(i, this.getValue(i));
        }
        return newVector;
    }

    /**
     * Get an ArrayList of existing indices. The list is a fresh snapshot, so the
     * caller may modify this vector while iterating it.
     *
     * @return An arraylist of integer, contain indices with valid items.
     */
    public ArrayList<Integer> indexList() {
        ArrayList<Integer> result = new ArrayList<Integer>(this.itemCount());
        for (int i : this.map) {
            result.add(i);
        }
        return result;
    }

    /**
     * Get a HashSet of existing indices (a fresh snapshot).
     *
     * @return A hashset of integer, contain indices with valid items.
     */
    public HashSet<Integer> indexSet() {
        HashSet<Integer> result = new HashSet<Integer>();
        for (int i : this.map) {
            result.add(i);
        }
        return result;
    }

    /**
     * Set a same value to every element in [0, length).
     *
     * @param value The value to assign to every element.
     */
    public void initialize(double value) {
        for (int i = 0; i < this.N; i++) {
            this.setValue(i, value);
        }
    }

    /**
     * Set same value to given indices.
     *
     * @param index The list of indices, which will be assigned the new value.
     * @param value The new value to be assigned.
     */
    public void initialize(int[] index, double value) {
        for (int i = 0; i < index.length; i++) {
            this.setValue(index[i], value);
        }
    }

    /*========================================
     * Properties
     *========================================*/
    /**
     * Capacity of this vector.
     *
     * @return The length of sparse vector
     */
    public int length() {
        return N;
    }

    /**
     * Actual number of items stored in the vector.
     *
     * @return The number of items in the vector.
     */
    public int itemCount() {
        return map.itemCount();
    }

    /**
     * Number of non-zero elements in the vector.
     *
     * @return The number of non-zero elements in the vector.
     */
    public int nonZeroCount() {
        int count = 0;
        for (int i : map) {
            if (map.get(i) != 0)
                count ++;
        }
        return count;
    }

    /**
     * Set a new capacity of the vector. Stored entries are not touched.
     *
     * @param n The new capacity value.
     */
    public void setLength(int n) {
        this.N = n;
    }

    /*========================================
     * Unary Vector operations
     *========================================*/
    /**
     * Scalar addition operator. Only stored entries are shifted; implicit zeros stay zero.
     *
     * @param alpha The scalar value to be added to the original vector.
     * @return The resulting vector, added by alpha.
     */
    public SparseVector add(double alpha) {
        SparseVector c = new SparseVector(N);
        for (int i : this.map) {
            c.setValue(i, alpha + this.getValue(i));
        }
        return c;
    }

    /**
     * Scalar subtraction operator. Only stored entries are shifted; implicit zeros stay zero.
     *
     * @param alpha The scalar value to be subtracted from the original vector.
     * @return The resulting vector, subtracted by alpha.
     */
    public SparseVector sub(double alpha) {
        SparseVector c = new SparseVector(N);
        for (int i : this.map) {
            c.setValue(i, this.getValue(i) - alpha);
        }
        return c;
    }

    /**
     * Scalar multiplication operator.
     *
     * @param alpha The scalar value to be multiplied to the original vector.
     * @return The resulting vector, multiplied by alpha.
     */
    public SparseVector scale(double alpha) {
        SparseVector c = new SparseVector(N);
        if (alpha == 0)
            return c;
        for (int i : this.map) {
            c.setValue(i, alpha * this.getValue(i));
        }
        return c;
    }

    /**
     * Scale multiplication operator on vector itself.
     *
     * @param alpha The scalar each stored entry is multiplied by.
     * @return This vector, after scaling in place.
     */
    public SparseVector selfScale(double alpha) {
        // Iterate a snapshot: a product of 0 removes the entry, which must not
        // happen on the map while it is being iterated.
        for (int i : this.indexList()) {
            this.setValue(i, alpha * this.getValue(i));
        }
        return this;
    }

    /**
     * Scalar power operator, applied to stored entries only.
     *
     * @param alpha The scalar value to be powered to the original vector.
     * @return The resulting vector, powered by alpha.
     */
    public SparseVector power(double alpha) {
        SparseVector c = new SparseVector(N);
        for (int i : this.map) {
            c.setValue(i, Math.pow(this.getValue(i), alpha));
        }
        return c;
    }

    /**
     * Exponential of a given constant: c_i = alpha ^ a_i for stored entries only.
     *
     * @param alpha The base.
     * @return The resulting exponential vector.
     */
    public SparseVector exp(double alpha) {
        SparseVector c = new SparseVector(N);
        for (int i : this.map) {
            c.setValue(i, Math.pow(alpha, this.getValue(i)));
        }
        return c;
    }

    /**
     * Maps every stored entry v to 1 + log2(v); implicit zeros stay zero.
     *
     * @return The transformed vector.
     */
    public SparseVector log2() {
        SparseVector c = new SparseVector(N);
        for (int i : this.map) {
            c.setValue(i, 1 + log2(this.getValue(i)));
        }
        return c;
    }

    /** Base-2 logarithm. */
    private double log2(double n) {
        return Math.log(n) / Math.log(2);
    }

    /**
     * Return a uniform vector of size n, every entry equal to 1/n.
     *
     * @param n The vector length.
     * @return The uniform vector.
     */
    public static SparseVector makeUniform(int n) {
        SparseVector v = new SparseVector(n);
        double val = 1.0 / n;
        for (int i = 0; i < n; i++) {
            v.setValue(i, val);
        }
        return v;
    }

    /**
     * Randomly generate a vector of dimension m, each value drawn by Math.random().
     *
     * @param m The vector length.
     * @return The random vector.
     */
    public static SparseVector makeRandom(int m) {
        SparseVector a = new SparseVector(m);
        for (int i = 0; i < m; i++) {
            a.setValue(i, Math.random());
        }
        return a;
    }

    /**
     * Calculate cosine similarity of two sparse vectors.
     *
     * @param a First vector.
     * @param b Second vector (same length as a).
     * @return The cosine similarity, or 0 if either vector is empty or they are orthogonal.
     */
    public static double cosineSimilarity(SparseVector a, SparseVector b) {
        if (a.itemCount() == 0 || b.itemCount() == 0)
            return 0;
        double innerProduct = a.innerProduct(b);
        return innerProduct == 0 ? 0 :
            innerProduct / (Math.sqrt(a.squareSum()) * Math.sqrt(b.squareSum()));
    }

    /**
     * 2-norm of the vector.
     *
     * @return 2-norm value of the vector.
     */
    public double norm() {
        return Math.sqrt(this.innerProduct(this));
    }

    /**
     * L1 normalization: scales the vector so that its elements sum to 1.
     *
     * @return The normalized vector.
     */
    public SparseVector L1_norm() {
        double sum = this.sum();
        return this.scale(1.0 / sum);
    }

    /**
     * Sum of every element in the vector.
     *
     * @return Sum value of every element.
     */
    public double sum() {
        double sum = 0.0;
        for (int i : this.map) {
            sum += this.getValue(i);
        }
        return sum;
    }

    /**
     * Square sum of all elements in the vector.
     *
     * @return Square sum of all elements.
     */
    public double squareSum() {
        return this.innerProduct(this);
    }

    /**
     * The value of maximum stored element in the vector.
     *
     * @return Maximum value in the vector, or -Double.MAX_VALUE when nothing is
     *         stored (mirroring {@link #min()}, which returns Double.MAX_VALUE).
     */
    public double max() {
        // Double.MIN_VALUE is the smallest *positive* double, so starting from it
        // made every all-negative vector report ~4.9e-324 as its maximum.
        double curr = -Double.MAX_VALUE;
        for (int i : this.map) {
            if (this.getValue(i) > curr) {
                curr = this.getValue(i);
            }
        }
        return curr;
    }

    /**
     * The value of minimum stored element in the vector.
     *
     * @return Minimum value in the vector, or Double.MAX_VALUE when nothing is stored.
     */
    public double min() {
        double curr = Double.MAX_VALUE;
        for (int i : this.map) {
            if (this.getValue(i) < curr) {
                curr = this.getValue(i);
            }
        }
        return curr;
    }

    /**
     * Sum of absolute value of every element in the vector.
     *
     * @return Sum of absolute value of every element.
     */
    public double absoluteSum() {
        double sum = 0.0;
        for (int i : this.map) {
            sum += Math.abs(this.getValue(i));
        }
        return sum;
    }

    /**
     * Average of every element. It ignores non-existing values.
     *
     * @return The average value (NaN when nothing is stored).
     */
    public double average() {
        return this.sum() / (double) this.itemCount();
    }

    /**
     * Variance of every element. It ignores non-existing values.
     *
     * @return The (population) variance value.
     */
    public double variance() {
        double avg = this.average();
        double sum = 0.0;
        for (int i : this.map) {
            sum += Math.pow(this.getValue(i) - avg, 2);
        }
        return sum / this.itemCount();
    }

    /**
     * Standard Deviation of every element. It ignores non-existing values.
     *
     * @return The standard deviation value.
     */
    public double stdev() {
        return Math.sqrt(this.variance());
    }

    /*========================================
     * Binary Vector operations
     *========================================*/
    /**
     * Vector sum (a + b)
     *
     * @param b The vector to be added to this vector.
     * @return The resulting vector after summation.
     * @throws RuntimeException If the vector lengths differ.
     */
    public SparseVector plus(SparseVector b) {
        if (this.N != b.N)
            throw new RuntimeException("Vector lengths disagree");
        SparseVector c = new SparseVector(N);
        for (int i : this.map)
            c.setValue(i, this.getValue(i)); // c = a
        for (int i : b.map)
            c.setValue(i, b.getValue(i) + c.getValue(i)); // c = c + b
        return c;
    }

    /**
     * Vector sum on itself (a = a + b)
     *
     * @param b The vector to be added.
     * @return This vector, after the in-place addition.
     * @throws RuntimeException If the vector lengths differ.
     */
    public SparseVector selfPlus(SparseVector b) {
        if (this.N != b.N)
            throw new RuntimeException("Vector lengths disagree");
        // Snapshot b's indices: when b == this, updating entries would otherwise
        // mutate the map that is being iterated.
        for (int i : b.indexList()) {
            this.setValue(i, this.getValue(i) + b.getValue(i));
        }
        return this;
    }

    /**
     * Vector subtraction (a - b)
     *
     * @param b The vector to be subtracted from this vector.
     * @return The resulting vector after subtraction.
     * @throws RuntimeException If the vector lengths differ.
     */
    public SparseVector minus(SparseVector b) {
        if (this.N != b.N)
            throw new RuntimeException("Vector lengths disagree");
        SparseVector c = new SparseVector(N);
        for (int i : this.map)
            c.setValue(i, this.getValue(i)); // c = a
        for (int i : b.map)
            c.setValue(i, c.getValue(i) - b.getValue(i)); // c = c - b
        return c;
    }

    /**
     * Vector subtraction on itself (a = a - b)
     *
     * @param b The vector to be subtracted.
     * @return This vector, after the in-place subtraction.
     * @throws RuntimeException If the vector lengths differ.
     */
    public SparseVector selfMinus(SparseVector b) {
        if (this.N != b.N)
            throw new RuntimeException("Vector lengths disagree");
        // Snapshot b's indices: with b == this every entry becomes 0 and is removed,
        // which must not happen on the map while it is being iterated.
        for (int i : b.indexList()) {
            this.setValue(i, this.getValue(i) - b.getValue(i));
        }
        return this;
    }

    /**
     * Vector subtraction (a - b), for only existing values.
     * The resulting vector can have a non-zero value only if both vectors have a value at the index.
     * Vector lengths are deliberately not compared.
     *
     * @param b The vector to be subtracted from this vector.
     * @return The resulting vector after subtraction (length of this vector).
     */
    public SparseVector commonMinus(SparseVector b) {
        SparseVector a = this;
        SparseVector c = new SparseVector(N);
        // iterate over the vector with the fewer items
        if (a.itemCount() <= b.itemCount()) {
            for (int i : a.map) {
                if (b.map.contains(i)) c.setValue(i, a.getValue(i) - b.getValue(i));
            }
        }
        else {
            for (int i : b.map) {
                if (a.map.contains(i)) c.setValue(i, a.getValue(i) - b.getValue(i));
            }
        }
        return c;
    }

    /**
     * Inner product of two vectors.
     *
     * @param b The vector to be inner-producted with this vector.
     * @return The inner-product value.
     * @throws RuntimeException If the vector lengths differ.
     */
    public double innerProduct(SparseVector b) {
        SparseVector a = this;
        double sum = 0.0;
        if (a.N != b.N)
            throw new RuntimeException("Vector lengths disagree");
        // iterate over the vector with the fewer items
        if (a.itemCount() <= b.itemCount()) {
            for (int i : a.map) {
                if (b.map.contains(i)) sum += a.getValue(i) * b.getValue(i);
            }
        }
        else {
            for (int i : b.map) {
                if (a.map.contains(i)) sum += a.getValue(i) * b.getValue(i);
            }
        }
        return sum;
    }

    /**
     * Outer product of two vectors.
     * NOTE(review): visits every (i, j) pair, i.e. O(length(a) * length(b)) work
     * even for very sparse inputs.
     *
     * @param b The vector to be outer-producted with this vector.
     * @return The resulting outer-product matrix (this.length() x b.length()).
     */
    public SparseMatrix outerProduct(SparseVector b) {
        SparseMatrix A = new SparseMatrix(this.N, b.N);
        for (int i = 0; i < this.N; i++) {
            for (int j = 0; j < b.N; j++) {
                A.setValue(i, j, this.getValue(i) * b.getValue(j));
            }
        }
        return A;
    }

    /**
     * Element-wise product of two vectors (c_i = a_i * b_i).
     *
     * @param b The other vector (same length).
     * @return The resulting element-wise product vector.
     * @throws RuntimeException If the vector lengths differ.
     */
    public SparseVector dotProduct(SparseVector b) {
        if (N != b.N)
            throw new RuntimeException("dotProduct Error - Vector lengths disagree");
        SparseVector c = new SparseVector(N);
        for (int i : map) {
            if (getValue(i) != 0 && b.getValue(i) != 0)
                c.setValue(i, getValue(i) * b.getValue(i));
        }
        return c;
    }

    /*========================================
     * Binary Vector operations (partial)
     *========================================*/
    /**
     * In-place vector sum (a + b) for indices only in the given indices.
     *
     * @param b The vector to be added to this vector.
     * @param indexList The list of indices to be applied summation; null means no-op.
     * @return This vector after summation.
     * @throws RuntimeException If the vector lengths differ.
     */
    public SparseVector partPlus(SparseVector b, int[] indexList) {
        if (indexList == null)
            return this;
        if (this.N != b.N)
            throw new RuntimeException("Vector lengths disagree");
        for (int i : indexList)
            this.setValue(i, this.getValue(i) + b.getValue(i)); // c = c + b
        return this;
    }

    /**
     * In-place vector subtraction (a - b) for indices only in the given indices.
     *
     * @param b The vector to be subtracted from this vector.
     * @param indexList The list of indices to be applied subtraction; null means no-op.
     * @return This vector after subtraction.
     * @throws RuntimeException If the vector lengths differ.
     */
    public SparseVector partMinus(SparseVector b, int[] indexList) {
        if (indexList == null)
            return this;
        if (this.N != b.N)
            throw new RuntimeException("Vector lengths disagree");
        for (int i : indexList)
            this.setValue(i, this.getValue(i) - b.getValue(i)); // c = c - b
        return this;
    }

    /**
     * Inner-product for indices only in the given indices.
     *
     * @param b The vector to be inner-producted with this vector.
     * @param indexList The list of indices to be applied inner-product; null yields 0.
     * @return The inner-product value.
     */
    public double partInnerProduct(SparseVector b, int[] indexList) {
        double sum = 0.0;
        if (indexList != null) {
            for (int i : indexList) {
                sum += this.getValue(i) * b.getValue(i);
            }
        }
        return sum;
    }

    /**
     * Outer-product for indices only in the given indices.
     *
     * @param b The vector to be outer-producted with this vector.
     * @param indexList The list of indices to be applied outer-product.
     * @return The outer-product matrix (this.length() x b.length()), or null if
     *         indexList is null.
     */
    public SparseMatrix partOuterProduct(SparseVector b, int[] indexList) {
        if (indexList == null)
            return null;
        // Rows are indexed by this vector and columns by b, so the row count must
        // come from this.N (it previously used b's length for both dimensions).
        SparseMatrix A = new SparseMatrix(this.N, b.N);
        for (int i : indexList) {
            for (int j : indexList) {
                A.setValue(i, j, this.getValue(i) * b.getValue(j));
            }
        }
        return A;
    }

    /**
     * Get the topK indices with largest values.
     *
     * @param topK Number of indices to return.
     * @param ignoreIndices Indices to ignore.
     * @return The selected indices, as produced by CommonUtils.TopKeysByValue.
     */
    public ArrayList<Integer> topIndicesByValue(int topK, ArrayList<Integer> ignoreIndices) {
        HashMap<Integer, Double> hashmap = new HashMap<Integer, Double>();
        for (int j : this.map) {
            hashmap.put(j, this.getValue(j));
        }
        return CommonUtils.TopKeysByValue(hashmap, topK, ignoreIndices);
    }

    /**
     * Convert the vector to a printable string.
     *
     * @return The resulted string in the form of "(1:\t5.000000) (2:\t4.500000) "
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int i : this.map) {
            sb.append(String.format("(%d:\t%.6f) ", i, map.get(i)));
        }
        return sb.toString();
    }

    /**
     * Render the stored indices, e.g. "[1, 5, ]" (note the trailing separator).
     *
     * @return The indices as a string.
     */
    public String KeysToString() {
        StringBuilder sb = new StringBuilder("[");
        for (int i : this.map) {
            sb.append(i).append(", ");
        }
        sb.append("]");
        return sb.toString();
    }
}
================================================
FILE: src/main/main.java
================================================
package main;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.NavigableMap;
import java.util.SortedSet;
import java.util.TreeSet;
import algorithms.*;
import utils.DatasetUtil;
import data_structure.DenseVector;
import data_structure.Rating;
import data_structure.SparseMatrix;
import data_structure.SparseVector;
import utils.Printer;
import utils.CommonUtils;
import java.util.ArrayList;
/**
* This is an abstract class for evaluating topK recommender systems (i.e. main functions.).
* Define some variables to use, and member functions to load data.
*
* @author HeXiangnan
* @since 2014.12.16
*/
public abstract class main {
/** Rating matrix for training. */
public static SparseMatrix trainMatrix;
/** Test ratings (sorted by time for global split). */
public static ArrayList testRatings;
public static int topK = 100;
public static int threadNum = 10;
public static int userCount;
public static int itemCount;
public static void ReadRatings_GlobalSplit(String ratingFile, double testRatio)
throws IOException {
userCount = itemCount = 0;
System.out.println("Global splitting with testRatio " + testRatio);
// Step 1. Construct data structure for sorting.
System.out.print("Read ratings and sort.");
long startTime = System.currentTimeMillis();
ArrayList ratings = new ArrayList();
BufferedReader reader = new BufferedReader(
new InputStreamReader(new FileInputStream(ratingFile)));
String line;
while((line = reader.readLine()) != null) {
Rating rating = new Rating(line);
ratings.add(rating);
userCount = Math.max(userCount, rating.userId);
itemCount = Math.max(itemCount, rating.itemId);
}
reader.close();
userCount ++;
itemCount ++;
// Step 2. Sort the ratings by time (small->large).
Comparator c = new Comparator() {
public int compare(Rating o1, Rating o2) {
if (o1.timestamp - o2.timestamp > 0) return 1;
else if (o1.timestamp - o2.timestamp < 0) return -1;
else return 0;
}
};
Collections.sort(ratings, c);
System.out.printf("[%s]\n", Printer.printTime(
System.currentTimeMillis() - startTime));
// Step 3. Generate trainMatrix and testStream
System.out.printf("Generate trainMatrix and testStream.");
startTime = System.currentTimeMillis();
trainMatrix = new SparseMatrix(userCount, itemCount);
testRatings = new ArrayList();
int testCount = (int) (ratings.size() * testRatio);
int count = 0;
for (Rating rating : ratings) {
if (count < ratings.size() - testCount) { // train
trainMatrix.setValue(rating.userId, rating.itemId, 1);
} else { // test
testRatings.add(rating);
}
count ++;
}
// Count number of new users/items/ratings in the test data
HashSet newUsers = new HashSet();
int newRatings = 0;
for (int u = 0; u < userCount; u ++) {
if (trainMatrix.getRowRef(u).itemCount() == 0) newUsers.add(u);
}
for (Rating rating : testRatings) {
if (newUsers.contains(rating.userId)) newRatings ++;
}
System.out.printf("[%s]\n", Printer.printTime(
System.currentTimeMillis() - startTime));
// Print some basic statistics of the dataset.
System.out.println ("Data\t" + ratingFile);
System.out.println ("#Users\t" + userCount + ", #newUser: " + newUsers.size());
System.out.println ("#Items\t" + itemCount);
System.out.printf("#Ratings\t %d (train), %d(test), %d(#newTestRatings)\n",
trainMatrix.itemCount(), testRatings.size(), newRatings);
}
/**
* Each line of .rating file is: userID\t itemID\t score\t timestamp.
* userID starts from 0 to num_user-1
* The items of each user is sorted by time (small->large).
*/
public static void ReadRatings_HoldOneOut(String ratingFile) throws IOException {
userCount = itemCount = 0;
System.out.println("HoldOne out splitting.");
// Step 1. Construct data structure for sorting.
System.out.print("Sort items for each user.");
long startTime = System.currentTimeMillis();
ArrayList> user_ratings = new ArrayList>();
BufferedReader reader = new BufferedReader(
new InputStreamReader(new FileInputStream(ratingFile)));
String line;
while((line = reader.readLine()) != null) {
Rating rating = new Rating(line);
if (user_ratings.size() - 1 < rating.userId) { // create a new user
user_ratings.add(new ArrayList());
}
user_ratings.get(rating.userId).add(rating);
userCount = Math.max(userCount, rating.userId);
itemCount = Math.max(itemCount, rating.itemId);
}
reader.close();
userCount ++;
itemCount ++;
assert userCount == user_ratings.size();
// Step 2. Sort the ratings of each user by time (small->large).
Comparator c = new Comparator() {
public int compare(Rating o1, Rating o2) {
if (o1.timestamp - o2.timestamp > 0) return 1;
else if (o1.timestamp - o2.timestamp < 0) return -1;
else return 0;
}
};
for (int u = 0; u < userCount; u ++) {
Collections.sort(user_ratings.get(u), c);
}
System.out.printf("[%s]\n", Printer.printTime(
System.currentTimeMillis() - startTime));
// Step 3. Generated splitted matrices (implicit 0/1 settings).
System.out.printf("Generate rating matrices.");
startTime = System.currentTimeMillis();
trainMatrix = new SparseMatrix(userCount, itemCount);
testRatings = new ArrayList();
for (int u = 0; u < userCount; u ++) {
ArrayList ratings = user_ratings.get(u);
for (int i = ratings.size() - 1; i >= 0; i --) {
int userId = ratings.get(i).userId;
int itemId = ratings.get(i).itemId;
if (i == ratings.size() - 1) { // test
testRatings.add(ratings.get(i));
} else { // train
trainMatrix.setValue(userId, itemId, 1);
}
}
}
System.out.printf("[%s]\n", Printer.printTime(
System.currentTimeMillis() - startTime));
// Print some basic statistics of the dataset.
System.out.println ("Data\t" + ratingFile);
System.out.println ("#Users\t" + userCount);
System.out.println ("#Items\t" + itemCount);
System.out.printf("#Ratings\t %d (train), %d(test)\n",
trainMatrix.itemCount(), testRatings.size());
}
/**
* Generate a smaller dataset.
* @param threshold
* @throws IOException
*/
public static void FilterRatingsWithThreshold(String ratingFile,
int userThreshold, int itemThreshold) throws IOException {
ArrayList> user_ratings = new ArrayList>();
System.out.println("Filter dataset with #user/item >= " + itemThreshold +
" and #item/user >= " + userThreshold);
// Read user ratings.
BufferedReader reader = new BufferedReader(
new InputStreamReader(new FileInputStream(ratingFile)));
HashMap map_item_count = new HashMap();
String line;
while((line = reader.readLine()) != null) {
Rating rating = new Rating(line);
if (user_ratings.size() - 1 < rating.userId) { // create a new user
user_ratings.add(new ArrayList());
}
user_ratings.get(rating.userId).add(rating);
if (!map_item_count.containsKey(rating.itemId)) {
map_item_count.put(rating.itemId, 0);
}
map_item_count.put(rating.itemId, map_item_count.get(rating.itemId) + 1);
}
reader.close();
// User filtering & item filtering
PrintWriter writer = new PrintWriter (new FileOutputStream(
ratingFile + "_i" + itemThreshold + "_u" + userThreshold));
HashMap map_user_id = new HashMap();
HashMap map_item_id = new HashMap();
int count = 0;
for (int u = 0; u < user_ratings.size(); u ++) {
ArrayList ratings = user_ratings.get(u);
int count_u = 0;
for (Rating rating : ratings) {
// item filtering
if (map_item_count.get(rating.itemId) < itemThreshold) continue;
count_u ++;
}
// user filtering
if (count_u < userThreshold) continue;
// write to files
for (Rating rating: ratings) {
if (map_item_count.get(rating.itemId) < itemThreshold) continue;
// Old item id and user id
String item = "" + rating.itemId;
String user = "" + rating.userId;
if (!map_item_id.containsKey(item)) {
map_item_id.put(item, map_item_id.size());
}
if (!map_user_id.containsKey(user)) {
map_user_id.put(user, map_user_id.size());
}
// New item id and user id
int userId = map_user_id.get(user);
int itemId = map_item_id.get(item);
writer.println(userId + "\t" + itemId + "\t" + rating.score + "\t" + rating.timestamp);
count ++;
}
}
System.out.printf("After filtering: #user:%d, #item:%d, #rating:%d \n",
map_user_id.size(), map_item_id.size(), count);
writer.close();
}
// Get some statistics about the dataset, e.g. user distribution on items
public static void DatasetStatistics(String ratingFile) throws IOException {
BufferedReader reader = new BufferedReader(
new InputStreamReader(new FileInputStream(ratingFile)));
// Read user ratings
int ratingCount = 0;
ArrayList> user_ratings = new ArrayList>();
String line;
while ((line = reader.readLine()) != null) {
Rating rating = new Rating(line);
ratingCount ++;
if (user_ratings.size() - 1 < rating.userId) { // create a new user
user_ratings.add(new ArrayList());
}
user_ratings.get(rating.userId).add(rating);
}
System.out.println("#Ratings in total: " + ratingCount);
// user distribution on items
HashMap map_count_users = new HashMap();
for (ArrayList ratings : user_ratings) {
int count = ratings.size();
if (!map_count_users.containsKey(ratings.size())) {
map_count_users.put(count, 0);
}
map_count_users.put(count, map_count_users.get(count) + 1);
}
List sortedKeys=new ArrayList(map_count_users.keySet());
Collections.sort(sortedKeys);
System.out.println("#rating\t#users (percentage)");
for (int count : sortedKeys) {
int users = map_count_users.get(count);
System.out.printf("%d\t %d (%.2f%%)\n", count, users,
(double)users / user_ratings.size() * 100 );
}
reader.close();
// Read item ratings
reader = new BufferedReader(new InputStreamReader(new FileInputStream(ratingFile)));
ArrayList> item_ratings = new ArrayList>();
while ((line = reader.readLine()) != null) {
Rating rating = new Rating(line);
if (item_ratings.size() - 1 < rating.itemId) { // create a new user
item_ratings.add(new ArrayList());
}
item_ratings.get(rating.itemId).add(rating);
}
// item distrubution on users
HashMap map_count_items = new HashMap();
for (ArrayList ratings : item_ratings) {
int count = ratings.size();
if (!map_count_items.containsKey(ratings.size())) {
map_count_items.put(count, 0);
}
map_count_items.put(count, map_count_items.get(count) + 1);
}
sortedKeys=new ArrayList(map_count_items.keySet());
Collections.sort(sortedKeys);
System.out.println("#rating\t#items (percentage)");
for (int count : sortedKeys) {
int items = map_count_items.get(count);
System.out.printf("%d\t %d (%.2f%%)\n", count, items,
(double)items / item_ratings.size() * 100 );
}
reader.close();
}
// Convert the movie-len-10M input(.dat) file to rating file.
public static void convertMLDatToRating(String ml_file) throws IOException {
BufferedReader reader = new BufferedReader(
new InputStreamReader(new FileInputStream(ml_file)));
PrintWriter writer = new PrintWriter (new FileOutputStream(ml_file + ".rating"));
int ratingCount = 0;
String splitter = "::";
HashMap map_item_id = new HashMap(); // id starts from 0
HashMap map_user_id = new HashMap();
String line;
while ((line = reader.readLine()) != null) {
String[] arr = line.split(splitter);
if (!map_user_id.containsKey(arr[0]))
map_user_id.put(arr[0], map_user_id.size());
if (!map_item_id.containsKey(arr[1]))
map_item_id.put(arr[1], map_item_id.size());
int userId = map_user_id.get(arr[0]);
int itemId = map_item_id.get(arr[1]);
writer.println(userId + "\t" + itemId + "\t" + arr[2] + "\t" + arr[3]);
ratingCount ++;
}
System.out.println("Converted " + ml_file + " to .rating file");
System.out.printf("#rating:%d, #user:%d, #item:%d \n",
ratingCount, map_user_id.size(), map_item_id.size());
reader.close();
writer.close();
}
// Convert the amazon review dataset (.vote) file to rating file.
public static void convertVoteToRating(String vote_file) throws IOException {
BufferedReader reader = new BufferedReader(
new InputStreamReader(new FileInputStream(vote_file)));
PrintWriter writer = new PrintWriter (new FileOutputStream(vote_file + ".rating"));
int ratingCount = 0;
String splitter = " ";
HashMap map_item_id = new HashMap(); // id starts from 0
HashMap map_user_id = new HashMap();
String line;
while ((line = reader.readLine()) != null) {
String[] arr = line.split(splitter);
if (!map_user_id.containsKey(arr[0]))
map_user_id.put(arr[0], map_user_id.size());
if (!map_item_id.containsKey(arr[1]))
map_item_id.put(arr[1], map_item_id.size());
int userId = map_user_id.get(arr[0]);
int itemId = map_item_id.get(arr[1]);
writer.println(userId + "\t" + itemId + "\t" + arr[2] + "\t" + arr[3]);
ratingCount ++;
}
System.out.println("Converted " + vote_file + " to .rating file");
System.out.printf("#rating:%d, #user:%d, #item:%d \n",
ratingCount, map_user_id.size(), map_item_id.size());
reader.close();
writer.close();
}
// Deduplicate the rating file by averaging the ratings for a (u,i) pair
// Note: after deduplication, timestamp is removed.
public static void deduplicate(String ratingFile) throws IOException {
// Read user ratings.
BufferedReader reader = new BufferedReader(
new InputStreamReader(new FileInputStream(ratingFile)));
int ratingCount = 0;
ArrayList> user_ratings = new ArrayList>();
String line;
while ((line = reader.readLine()) != null) {
Rating rating = new Rating(line);
ratingCount ++;
if (user_ratings.size() - 1 < rating.userId) { // create a new user
user_ratings.add(new ArrayList());
}
user_ratings.get(rating.userId).add(rating);
}
System.out.println("#Ratings in total: " + ratingCount);
reader.close();
// Deduplicate and Writing to file
PrintWriter writer = new PrintWriter (new FileOutputStream(ratingFile + ".deduplicate"));
ratingCount = 0;
for (int u = 0; u < user_ratings.size(); u ++) {
ArrayList ratings = user_ratings.get(u);
HashMap map_item_score = new HashMap();
HashMap map_item_count = new HashMap();
for (Rating rating: ratings) {
if (!map_item_score.containsKey(rating.itemId)) {
map_item_score.put(rating.itemId, 0.0);
map_item_count.put(rating.itemId, 0);
}
map_item_score.put(rating.itemId, map_item_score.get(rating.itemId) + rating.score);
map_item_count.put(rating.itemId, map_item_count.get(rating.itemId) + 1);
}
for (int i : map_item_score.keySet()) {
double score = map_item_score.get(i) / map_item_count.get(i);
writer.printf("%d\t%d\t%.1f\n", u+1, i+1, score);
ratingCount ++;
}
}
writer.close();
System.out.println("#After dedepulicate, #ratings: " + ratingCount);
}
public static void main(String[] args) throws IOException {
String dataset ="hanwang-data/amazon_books_filter.rating";
deduplicate(dataset);
//String dataset = "data/yelp.rating";
//ReadRatings_HoldOneOut("data/yelp.rating");
//FilterRatingsWithThreshold(dataset, 10, 10);
//DatasetStatistics(dataset);
//convertVoteToRating(dataset);
//FilterRatingsWithThreshold(dataset, 10, 10);
}
// Evaluate the model
public static double[] evaluate_model(TopKRecommender model, String name) {
long start = System.currentTimeMillis();
model.buildModel();
model.evaluate(testRatings);
double[] res = new double[3];
res[0] = model.hits.mean();
res[1] = model.ndcgs.mean();
res[2] = model.precs.mean();
System.out.printf("%s\t
:\t %.4f\t %.4f\t %.4f [%s]\n",
name, res[0], res[1], res[2],
Printer.printTime(System.currentTimeMillis() - start));
return res;
}
// Evaluate the model by online protocol
public static void evaluate_model_online(TopKRecommender model, String name, int interval) {
long start = System.currentTimeMillis();
model.evaluateOnline(testRatings, interval);
System.out.printf("%s\t
:\t %.4f\t %.4f\t %.4f [%s]\n",
name, model.hits.mean(), model.ndcgs.mean(), model.precs.mean(),
Printer.printTime(System.currentTimeMillis() - start));
}
}
/**
 * Thread that performs exactly one training iteration of a recommender when started.
 */
class ModelThread extends Thread {
    /** The recommender whose runOneIteration() is invoked by {@link #run()}. */
    TopKRecommender model;

    public ModelThread(TopKRecommender model) {
        super();
        this.model = model;
    }

    @Override
    public void run() {
        TopKRecommender target = this.model;
        target.runOneIteration();
    }
}
================================================
FILE: src/main/main_MF.java
================================================
package main;
import java.io.IOException;
import data_structure.DenseMatrix;
import utils.Printer;
import algorithms.MF_fastALS;
import algorithms.MF_ALS;
import algorithms.MF_CD;
import algorithms.ItemPopularity;
/**
 * Command-line driver that compares item popularity with the matrix-factorization learners
 * MF_fastALS, MF_ALS and MF_CD on data/DATASET.rating, using ReadRatings_HoldOneOut.
 *
 * <p>Positional arguments, each optional (missing trailing ones keep their defaults):
 * {@code dataset method w0 showProgress showLoss factors maxIter reg alpha}.
 * {@code method} is matched case-insensitively against FastALS, ALS, CD and All.
 */
public class main_MF extends main {
    public static void main(String argv[]) throws IOException {
        String dataset_name = "yelp";
        String method = "FastALS";
        double w0 = 10;
        boolean showProgress = false;
        boolean showLoss = true;
        int factors = 64;
        int maxIter = 500;
        double reg = 0.01;
        double alpha = 0.75;

        // Every argument is optional, so a partial list (e.g. only dataset and method) keeps
        // the defaults above instead of failing with ArrayIndexOutOfBoundsException.
        if (argv.length > 0) dataset_name = argv[0];
        if (argv.length > 1) method = argv[1];
        if (argv.length > 2) w0 = Double.parseDouble(argv[2]);
        if (argv.length > 3) showProgress = Boolean.parseBoolean(argv[3]);
        if (argv.length > 4) showLoss = Boolean.parseBoolean(argv[4]);
        if (argv.length > 5) factors = Integer.parseInt(argv[5]);
        if (argv.length > 6) maxIter = Integer.parseInt(argv[6]);
        if (argv.length > 7) reg = Double.parseDouble(argv[7]);
        if (argv.length > 8) alpha = Double.parseDouble(argv[8]);

        // Alternative protocol: ReadRatings_GlobalSplit("data/" + dataset_name + ".rating", 0.1);
        ReadRatings_HoldOneOut("data/" + dataset_name + ".rating");

        System.out.printf("%s: showProgress=%s, factors=%d, maxIter=%d, reg=%f, w0=%.2f, alpha=%.2f\n",
                method, showProgress, factors, maxIter, reg, w0, alpha);
        System.out.println("====================================================");

        // Non-personalized baseline, always reported.
        ItemPopularity popularity = new ItemPopularity(trainMatrix, testRatings, topK, threadNum);
        evaluate_model(popularity, "Popularity");

        double init_mean = 0;
        double init_stdev = 0.01;

        if (method.equalsIgnoreCase("fastals")) {
            MF_fastALS fals = new MF_fastALS(trainMatrix, testRatings, topK, threadNum,
                    factors, maxIter, w0, alpha, reg, init_mean, init_stdev, showProgress, showLoss);
            evaluate_model(fals, "MF_fastALS");
        }

        if (method.equalsIgnoreCase("als")) {
            MF_ALS als = new MF_ALS(trainMatrix, testRatings, topK, threadNum,
                    factors, maxIter, w0, reg, init_mean, init_stdev, showProgress, showLoss);
            evaluate_model(als, "MF_ALS");
        }

        if (method.equalsIgnoreCase("cd")) {
            MF_CD cd = new MF_CD(trainMatrix, testRatings, topK, threadNum,
                    factors, maxIter, w0, reg, init_mean, init_stdev, showProgress, showLoss);
            evaluate_model(cd, "MF_CD");
        }

        if (method.equalsIgnoreCase("all")) {
            // One random initialization handed to all three learners via setUV.
            // NOTE(review): this yields identical starting points only if setUV copies U/V;
            // if it keeps references, each model starts from the previous model's factors.
            DenseMatrix U = new DenseMatrix(userCount, factors);
            DenseMatrix V = new DenseMatrix(itemCount, factors);
            U.init(init_mean, init_stdev);
            V.init(init_mean, init_stdev);

            MF_fastALS fals = new MF_fastALS(trainMatrix, testRatings, topK, threadNum,
                    factors, maxIter, w0, alpha, reg, init_mean, init_stdev, showProgress, showLoss);
            fals.setUV(U, V);
            evaluate_model(fals, "MF_fastALS");

            MF_ALS als = new MF_ALS(trainMatrix, testRatings, topK, threadNum,
                    factors, maxIter, w0, reg, init_mean, init_stdev, showProgress, showLoss);
            als.setUV(U, V);
            evaluate_model(als, "MF_ALS");

            MF_CD cd = new MF_CD(trainMatrix, testRatings, topK, threadNum,
                    factors, maxIter, w0, reg, init_mean, init_stdev, showProgress, showLoss);
            cd.setUV(U, V);
            evaluate_model(cd, "MF_CD");
        }
    } // end main
}
================================================
FILE: src/main/main_bpr.java
================================================
package main;
import java.io.IOException;
import utils.Printer;
import algorithms.MFbpr;
import algorithms.ItemPopularity;
import algorithms.TopKRecommender;
import data_structure.Rating;
import java.util.ArrayList;
/**
 * Command-line driver that compares item popularity with BPR matrix factorization (MFbpr)
 * on data/DATASET.rating, using ReadRatings_HoldOneOut and a top-100 ranking list.
 *
 * <p>Positional arguments, each optional (missing trailing ones keep their defaults):
 * {@code dataset factors lr reg}.
 */
public class main_bpr extends main {
    public static void main(String argv[]) throws IOException {
        String dataset_name = "yelp";
        int factors = 16;
        double lr = 0.01;
        double reg = 0.01;
        int num_dns = 1; // number of dynamic negative samples [Zhang Weinan et al. SIGIR 2013]
        int maxIter = 1000;
        double init_mean = 0;
        double init_stdev = 0.01;

        // Every argument is optional, so a partial list keeps the defaults above instead of
        // failing with ArrayIndexOutOfBoundsException.
        if (argv.length > 0) dataset_name = argv[0];
        if (argv.length > 1) factors = Integer.parseInt(argv[1]);
        if (argv.length > 2) lr = Double.parseDouble(argv[2]);
        if (argv.length > 3) reg = Double.parseDouble(argv[3]);

        ReadRatings_HoldOneOut("data/" + dataset_name + ".rating");
        // This experiment overrides the inherited topK with a longer ranking list.
        topK = 100;

        System.out.printf("BPR with factors=%d, lr=%.4f, reg=%.4f, num_dns=%d\n",
                factors, lr, reg, num_dns);
        System.out.println("====================================================");

        ItemPopularity pop = new ItemPopularity(trainMatrix, testRatings, topK, threadNum);
        evaluate_model(pop, "Popularity");

        // NOTE(review): the meaning of the two boolean flags is defined by MFbpr's constructor.
        MFbpr bpr = new MFbpr(trainMatrix, testRatings, topK, threadNum,
                factors, maxIter, lr, false, reg, init_mean, init_stdev, num_dns, true);
        evaluate_model(bpr, "BPR");
    } // end main
}
================================================
FILE: src/main/main_online.java
================================================
package main;
import java.io.IOException;
import data_structure.DenseMatrix;
import utils.Printer;
import algorithms.MF_fastALS;
import algorithms.MF_ALS;
import algorithms.MF_CD;
import algorithms.ItemPopularity;
import algorithms.MFbpr;
/**
 * Command-line driver for the online protocol. It loads data/DATASET.rating with
 * ReadRatings_GlobalSplit (ratio 0.1), builds the selected model on the training part and
 * then evaluates it with evaluate_model_online.
 *
 * <p>Positional arguments, each optional (missing trailing ones keep their defaults):
 * {@code dataset method interval w0 factors maxIter maxIterOnline alpha onlineMode w_new}.
 * {@code method} is matched case-insensitively against ALS, FastALS, CD and BPR;
 * {@code onlineMode} is only used by BPR.
 */
public class main_online extends main {
    public static void main(String argv[]) throws IOException {
        String dataset_name = "yelp";
        String method = "FastALS";
        int interval = 1000;
        double w0 = 512;
        int factors = 64;
        int maxIter = 50;
        int maxIterOnline = 1;
        double alpha = 0.4;
        String onlineMode = "ui";
        double w_new = 1;

        // Every argument is optional, so a partial list keeps the defaults above instead of
        // failing with ArrayIndexOutOfBoundsException.
        if (argv.length > 0) dataset_name = argv[0];
        if (argv.length > 1) method = argv[1];
        if (argv.length > 2) interval = Integer.parseInt(argv[2]);
        if (argv.length > 3) w0 = Double.parseDouble(argv[3]);
        if (argv.length > 4) factors = Integer.parseInt(argv[4]);
        if (argv.length > 5) maxIter = Integer.parseInt(argv[5]);
        if (argv.length > 6) maxIterOnline = Integer.parseInt(argv[6]);
        if (argv.length > 7) alpha = Double.parseDouble(argv[7]);
        if (argv.length > 8) onlineMode = argv[8];
        if (argv.length > 9) w_new = Double.parseDouble(argv[9]);

        ReadRatings_GlobalSplit("data/" + dataset_name + ".rating", 0.1);

        System.out.printf("Online evaluation for %s: factors=%d, maxIter=%d, maxInterOnline=%d, interval=%d, onlineMode(bpr only)=%s\n",
                method, factors, maxIter, maxIterOnline, interval, onlineMode);
        System.out.println("====================================================");

        ItemPopularity popularity = new ItemPopularity(trainMatrix, testRatings, topK, threadNum);
        evaluate_model_online(popularity, "Popularity", interval);

        double init_mean = 0;
        double init_stdev = 0.01;
        double reg = 0.01;
        boolean showProgress = false;
        boolean showLoss = false;

        // Each model is first built offline; maxIterOnline is set afterwards so that it only
        // governs the per-update iterations during online evaluation.

        // ALS is not well suited to online learning; it only runs when requested explicitly.
        if (method.equalsIgnoreCase("als")) {
            MF_ALS als = new MF_ALS(trainMatrix, testRatings, topK, threadNum,
                    factors, maxIter, w0, reg, init_mean, init_stdev, showProgress, showLoss);
            als.buildModel();
            als.maxIterOnline = maxIterOnline;
            evaluate_model_online(als, "MF_ALS", interval);
        }

        if (method.equalsIgnoreCase("fastals")) {
            MF_fastALS fals = new MF_fastALS(trainMatrix, testRatings, topK, threadNum,
                    factors, maxIter, w0, alpha, reg, init_mean, init_stdev, showProgress, showLoss);
            fals.w_new = w_new;
            fals.buildModel();
            fals.maxIterOnline = maxIterOnline;
            evaluate_model_online(fals, "MF_fastALS", interval);
        }

        if (method.equalsIgnoreCase("cd")) {
            MF_CD cd = new MF_CD(trainMatrix, testRatings, topK, threadNum,
                    factors, maxIter, w0, reg, init_mean, init_stdev, showProgress, showLoss);
            cd.w_new = w_new;
            cd.buildModel();
            cd.maxIterOnline = maxIterOnline;
            evaluate_model_online(cd, "MF_CD", interval);
        }

        if (method.equalsIgnoreCase("bpr")) {
            MFbpr bpr = new MFbpr(trainMatrix, testRatings, topK, threadNum,
                    factors, maxIter, 0.01, false, reg, init_mean, init_stdev, 1, showProgress);
            bpr.onlineMode = onlineMode;
            bpr.buildModel();
            bpr.maxIterOnline = maxIterOnline;
            evaluate_model_online(bpr, "BPR", interval);
        }
    } // end main
}
================================================
FILE: src/utils/CommonUtils.java
================================================
package utils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.HashMap;
import java.util.Random;
public class CommonUtils {
/**
* Sort the HashMap by its values, from Large->Small.
* @return List> with sorted entries.
*/
public static> List> SortMapByValue(Map map) {
List> infoIds = new ArrayList>(map.entrySet());
Comparator> c = new Comparator>() {
public int compare(Map.Entry o1, Map.Entry o2) {
return o2.getValue().compareTo(o1.getValue());
}};
Collections.sort(infoIds, c);
return infoIds;
}
/**
* Get the topK keys (by its value) of a map. Does not consider the keys which are in ignoreKeys.
* @param map
* @return
*/
public static> ArrayList TopKeysByValue(Map map,
int topK, ArrayList ignoreKeys) {
HashSet ignoreSet;
if (ignoreKeys == null) {
ignoreSet = new HashSet();
} else {
ignoreSet = new HashSet (ignoreKeys);
}
TopKPriorityQueue topQueue = new TopKPriorityQueue(topK);
for (Map.Entry entry : map.entrySet()) {
if (!ignoreSet.contains(entry.getKey())) {
topQueue.add(entry);
}
}
ArrayList topKeys = new ArrayList();
for (Map.Entry entry : topQueue.sortedList()) {
topKeys.add(entry.getKey());
}
return topKeys;
/*
//Another implementation that first sorting.
List> topEntities = SortMapByValue(map);
ArrayList topKeys = new ArrayList();
for (Map.Entry entity : topEntities) {
if (topKeys.size() >= topK) break;
if (!ignoreSet.contains(entity.getKey())) {
topKeys.add(entity.getKey());
}
}
return topKeys; */
}
/**
* Convert an int[] to ArrayList
*/
public static ArrayList ArrayToArraylist(int[] array) {
if (array == null) {
return new ArrayList();
}
ArrayList list = new ArrayList(array.length);
for (int val : array) {
list.add(val);
}
return list;
}
/**
* Count number of matches of findStr in str.
* @param str
* @param findStr
* @return
*/
public static int CountMatchesInString(String str, String findStr) {
int lastIndex = 0;
int count = 0;
while(lastIndex != -1) {
lastIndex = str.indexOf(findStr,lastIndex);
if( lastIndex != -1) {
count ++;
lastIndex+=findStr.length();
}
}
return count;
}
/**
* Convert a string to k-gram set.
* @param str
* @param size
*/
public static ArrayList StringToGramSet(String str, int k) {
ArrayList grams = new ArrayList();
String[] words = str.split(" ");
for(int i = 0; i <= words.length-k; i ++) {
String gram = words[i];
for (int j = 1; j < k; j++) {
gram += " " + words[i+j];
}
grams.add(gram.trim());
}
return grams;
}
/**
* Randomly shuffle an int array.
* @param array
*/
public static void ShuffleArray(int[] array)
{
int index, temp;
Random random = new Random();
for (int i = array.length - 1; i > 0; i--)
{
index = random.nextInt(i + 1);
temp = array[index];
array[index] = array[i];
array[i] = temp;
}
}
}
================================================
FILE: src/utils/DatasetUtil.java
================================================
package utils;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import org.json.simple.JSONObject;
import org.json.simple.parser.ParseException;
import org.json.simple.parser.JSONParser;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import data_structure.SparseMatrix;
import utils.StopwordsFilter;
import data_structure.SparseVector;
/**
 * Represent each review: one record of a .votes file (user, item, rating, time, number of
 * review words and the review text).
 * @author HeXiangnan
 *
 */
class Vote {
    public String user;
    public String item;
    public double rating;
    public Integer time;
    public int wordCount;
    public String review;

    public Vote(String user, String item, double rating, int time, int wordCount, String review) {
        this.user = user;
        this.item = item;
        this.rating = rating;
        this.time = time;
        this.wordCount = wordCount;
        this.review = review;
    }

    /**
     * Sorts votes in place by review time, from small (old) to large (recent).
     * Collections.sort is stable, so votes with equal times keep their relative order.
     *
     * @param votes the list to sort in place
     */
    public static void sortByTime(ArrayList<Vote> votes) {
        Comparator<Vote> comparator = new Comparator<Vote>() {
            @Override
            public int compare(Vote vote0, Vote vote1) {
                return vote0.time.compareTo(vote1.time);
            }
        };
        Collections.sort(votes, comparator);
    }

    /**
     * Formats the vote as a space-separated line:
     * {@code user item rating(%.1f) time wordCount review}.
     */
    @Override
    public String toString() {
        return String.format("%s %s %.1f %d %d %s", user, item, rating, time, wordCount, review);
    }
}
public class DatasetUtil {
// Shared input reader; many methods of this class (re)assign it to the file they read.
private BufferedReader reader = null;

/** Creates the utility; every operation receives its input directory and dataset name. */
public DatasetUtil() {
    super();
}
/*==============================================================================================
* Process datasets, e.g. converting to .votes file,
* splitting(train, test, validation) and filtering dataset.
*==============================================================================================*/
/**
 * Convert the original Amazon datasets (as provided by HFT, RecSys'13) into a votes file.
 * <p>Input: {@code inputfileDir + dataset + ".txt"}, made of {@code key: value} lines. The
 * keys used are product/productId, review/userId, review/score, review/time and
 * review/text; a review is emitted when its review/text line is read, after which the
 * collected fields are reset.
 * <p>Output: {@code inputfileDir + dataset + ".votes"}, one line per review:
 * {@code userID itemID rating time #words reviewText}, where the review text is written
 * as-is (not lower-cased or tokenized) and #words is its count of space-separated tokens.
 * A progress dot is printed every 10000 input lines.
 *
 * @param inputfileDir Directory of input dataset (expected to end with a path separator).
 * @param dataset Dataset name.
 * @throws IOException if the input cannot be read or the output cannot be created
 */
public void ConvertTxtToVotesFile(String inputfileDir, String dataset)
        throws IOException {
    String inputfileName = inputfileDir + dataset + ".txt";
    String outputfileName = inputfileDir + dataset + ".votes";
    System.out.println("\nConverting to .votes file: " + inputfileName);
    reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfileName)));
    PrintWriter writer = null;
    try {
        writer = new PrintWriter(new FileOutputStream(outputfileName));
        String line;
        String productId = "", userId = "", rating = "", time = "";
        int count = 0;
        while ((line = reader.readLine()) != null) {
            if (line.contains(":")) {
                // Split at the first ':' only: review text may itself contain ':', and the
                // limit keeps an empty value as "" instead of dropping it.
                String[] segments = line.split(":", 2);
                String linename = segments[0].trim();
                String value = segments[1].trim();
                if (linename.equals("product/productId")) {
                    productId = value;
                } else if (linename.equals("review/userId")) {
                    userId = value;
                } else if (linename.equals("review/score")) {
                    rating = value;
                } else if (linename.equals("review/time")) {
                    time = value;
                } else if (linename.equals("review/text")) {
                    writer.println(userId + " " + productId + " " + rating + " " +
                            time + " " + value.split(" ").length + " " + value);
                    productId = userId = rating = time = "";
                }
            }
            if (count++ % 10000 == 0)
                System.out.print(".");
        }
    } finally {
        if (writer != null) writer.close();
        reader.close();
    }
}
/**
 * Convert the original Yelp Challenge dataset (one JSON object per line) into a votes file.
 * <p>Input: {@code inputfileDir + dataset + ".json"}; fields read: user_id, business_id,
 * stars, date (yyyy-MM-dd) and text.
 * <p>Output: {@code inputfileDir + dataset + ".votes"}, one line per review:
 * {@code userID itemID rating unixTime #words words}, where the words come from
 * {@code parseSentence(text)}, are lower-cased and are each followed by one space.
 * <p>The date is parsed as 08:00 in the JVM's default time zone and converted to Unix
 * seconds. A progress dot is printed every 10000 reviews.
 *
 * @param inputfileDir directory of the dataset (expected to end with a path separator)
 * @param dataset      dataset name
 * @throws IOException if reading or writing fails
 * @throws ParseException if a line is not valid JSON
 * @throws java.text.ParseException if a date cannot be parsed
 */
public void ConvertJsonToVotesFile(String inputfileDir, String dataset) throws IOException, ParseException, java.text.ParseException {
    String inputfileName = inputfileDir + dataset + ".json";
    String outputfileName = inputfileDir + dataset + ".votes";
    System.out.println("\nConverting to .votes file: " + inputfileName);
    BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfileName)));
    PrintWriter writer = null;
    try {
        writer = new PrintWriter(new FileOutputStream(outputfileName));
        JSONParser parser = new JSONParser();
        // Created once: this loop is single-threaded, so reusing the formatter is safe.
        DateFormat dfm = new SimpleDateFormat("yyyyMMddHHmm");
        String line;
        int count = 0;
        while ((line = reader.readLine()) != null) {
            JSONObject obj = (JSONObject) parser.parse(line);
            String user_id = (String) obj.get("user_id");
            String business_id = (String) obj.get("business_id");
            // json-simple yields Long for integral numbers; fractional stars arrive as Double
            // and previously failed the (Long) cast. Integral values still print as "4.0".
            Object stars = obj.get("stars");
            String score = (stars instanceof Double) ? String.valueOf(stars) : stars + ".0";

            // Parse time to unix time.
            String date = (String) obj.get("date");
            long unixtime = dfm.parse(date.replace("-", "") + "0800").getTime() / 1000;

            // Parse review words.
            String[] review_words = parseSentence((String) obj.get("text"));
            StringBuilder parse_review_text = new StringBuilder();
            for (String review_word : review_words) {
                parse_review_text.append(review_word.toLowerCase()).append(' ');
            }

            // Output to the .votes file.
            writer.println(user_id + " " + business_id + " " + score + " " +
                    unixtime + " " + review_words.length + " " + parse_review_text);
            if (count++ % 10000 == 0)
                System.out.print(".");
        }
        System.out.println("#reviews: " + count);
    } finally {
        if (writer != null) writer.close();
        reader.close();
    }
}
/**
 * Extracts the review text of a Yelp Challenge JSON dump into a .raw file, which is consumed
 * by the thuir-sentires.jar lexicon-construction tool.
 * <p>Reads {@code inputfileDir + dataset + ".json"} (one JSON object per line) and writes
 * {@code inputfileDir + dataset + ".raw"}: for every object, an empty line, the value of
 * its "text" field, then another empty line. A progress dot is printed every 10000 lines.
 * <p>NOTE(review): the surrounding empty lines look like record delimiters that may have
 * been lost; confirm the exact format thuir-sentires.jar expects.
 *
 * @param inputfileDir directory of the dataset (expected to end with a path separator)
 * @param dataset      dataset name
 * @throws IOException if reading or writing fails
 * @throws ParseException if a line is not valid JSON
 * @throws java.text.ParseException declared for signature compatibility
 */
public void ConvertJsonToRawFile(String inputfileDir, String dataset) throws IOException, ParseException, java.text.ParseException {
    final String jsonPath = inputfileDir + dataset + ".json";
    final String rawPath = inputfileDir + dataset + ".raw";
    System.out.println("\nConverting to .raw file: " + jsonPath);
    reader = new BufferedReader(new InputStreamReader(new FileInputStream(jsonPath)));
    PrintWriter out = new PrintWriter(new FileOutputStream(rawPath));
    JSONParser jsonParser = new JSONParser();
    int processed = 0;
    for (String row = reader.readLine(); row != null; row = reader.readLine()) {
        JSONObject record = (JSONObject) jsonParser.parse(row);
        String text = (String) record.get("text");
        out.println("");
        out.println(text);
        out.println("");
        if (processed % 10000 == 0) {
            System.out.print(".");
        }
        processed++;
    }
    System.out.println("\nGenerated .raw file" + rawPath);
    reader.close();
    out.close();
}
/**
 * Converts a .votes file into a .rating file.
 * <p>Reads {@code inputfileDir + dataset + ".votes"} and writes
 * {@code inputfileDir + dataset + ".rating"}, where each line is
 * {@code user_id\titem_id\trating} with the rating formatted by {@code %f}.
 * Lines for which {@code parseVotesLine} returns null are skipped.
 *
 * @param inputfileDir directory of the dataset (expected to end with a path separator)
 * @param dataset      dataset name
 * @throws IOException if reading or writing fails
 */
public void ConvertVotesToRatingFile(String inputfileDir, String dataset) throws IOException {
    String inputfileName = inputfileDir + dataset + ".votes";
    String outputfileName = inputfileDir + dataset + ".rating";
    System.out.println("\nConverting .votes to .rating file: " + inputfileName);
    reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfileName)));
    PrintWriter writer = null;
    try {
        writer = new PrintWriter(new FileOutputStream(outputfileName));
        String line;
        while ((line = reader.readLine()) != null) {
            // Parse each line once. Other methods in this class treat a null result as an
            // unparsable line, so skip it rather than fail with a NullPointerException.
            Vote vote = parseVotesLine(line);
            if (vote == null) {
                continue;
            }
            // Output to the .rating file.
            writer.printf("%s\t%s\t%f\n", vote.user, vote.item, vote.rating);
        }
        System.out.println("Generated .rating file" + outputfileName);
    } finally {
        if (writer != null) writer.close();
        reader.close();
    }
}
/**
 * Writes the review text of every vote to a .raw file.
 * <p>Reads {@code inputfileDir + dataset + ".votes"} and writes
 * {@code inputfileDir + dataset + ".raw"}: for each vote, an empty line, the review text,
 * then another empty line. Lines for which {@code parseVotesLine} returns null are skipped.
 * A progress dot is printed every 10000 reviews.
 *
 * @param inputfileDir directory of the dataset (expected to end with a path separator)
 * @param dataset      dataset name
 * @throws IOException if reading or writing fails
 */
public void ConvertVotesToRawFile(String inputfileDir, String dataset) throws IOException {
    String inputfileName = inputfileDir + dataset + ".votes";
    String outputfileName = inputfileDir + dataset + ".raw";
    System.out.println("\nConverting .votes to .raw file: " + inputfileName);
    reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfileName)));
    PrintWriter writer = null;
    try {
        writer = new PrintWriter(new FileOutputStream(outputfileName));
        String line;
        int count = 0;
        while ((line = reader.readLine()) != null) {
            Vote vote = parseVotesLine(line);
            if (vote == null) {
                continue; // unparsable line; dereferencing it would throw
            }
            // Output to the .raw file.
            writer.println("");
            writer.println(vote.review);
            writer.println("");
            if (count++ % 10000 == 0)
                System.out.print(".");
        }
        System.out.println("\nGenerated .raw file" + outputfileName);
    } finally {
        if (writer != null) writer.close();
        reader.close();
    }
}
/**
 * Extracts the review/text values of an Amazon .txt dump into a .raw file.
 * <p>Reads {@code inputfileDir + dataset + ".txt"} ({@code key: value} lines) and writes
 * {@code inputfileDir + dataset + ".raw"}: for each review/text line, an empty line, the
 * full review text, then another empty line. A progress dot is printed every 10000
 * {@code key: value} lines.
 *
 * @param inputfileDir directory of the dataset (expected to end with a path separator)
 * @param dataset      dataset name
 * @throws IOException if reading or writing fails
 */
public void ConvertTxtToRawFile(String inputfileDir, String dataset) throws IOException {
    String inputfileName = inputfileDir + dataset + ".txt";
    String outputfileName = inputfileDir + dataset + ".raw";
    System.out.println("\nConverting to .raw file: " + inputfileName);
    reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfileName)));
    PrintWriter writer = null;
    try {
        writer = new PrintWriter(new FileOutputStream(outputfileName));
        String line;
        int count = 0;
        while ((line = reader.readLine()) != null) {
            if (line.contains(":")) {
                // Split at the first ':' only so review text containing ':' is kept whole
                // and an empty value does not cause an ArrayIndexOutOfBoundsException.
                String[] segments = line.split(":", 2);
                String linename = segments[0].trim();
                if (linename.equals("review/text")) {
                    // Output to the raw file.
                    writer.println("");
                    writer.println(segments[1].trim());
                    writer.println("");
                }
                if (count++ % 10000 == 0)
                    System.out.print(".");
            }
        }
    } finally {
        if (writer != null) writer.close();
        reader.close();
    }
}
/**
* If a user has rated an item multiple times, using the recent one.
* @param inputDir
* @param dataset
* @throws IOException
*/
public void RemoveDuplicateInVotesFile(String inputDir, String dataset) throws IOException {
String inputFile = inputDir + dataset +".votes";
reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputFile)));
String outputFile = inputDir + dataset + ".votes.noDuplicate";
PrintWriter writer = new PrintWriter (new FileOutputStream(outputFile));
// Build map, where key is userID_itemID
HashMap> map = new HashMap>();
String line;
int count = 0;
while((line = reader.readLine()) != null) {
Vote vote = parseVotesLine(line);
String key = vote.user + "_" + vote.item;
if (!map.containsKey(key)) {
map.put(key, new ArrayList());
}
map.get(key).add(vote);
count ++;
}
// Write file.
for (Entry> it : map.entrySet()) {
ArrayList votes = it.getValue();
if(it.getValue().size() > 1) { // write the latest vote.
Vote.sortByTime(votes);
}
writer.println(votes.get(votes.size() - 1).toString());
}
System.out.printf("Before removing duplicates, #lines: %d, after: %d\n", count, map.size());
System.out.printf("Generated file: %s\n", outputFile);
reader.close();
writer.close();
}
/**
*
* @param inputfileDir
* @param K Number of test items to holdout for each user.
* @throws IOException
*/
public void SplitVotesFileRandomAllButK(String inputfileDir, String dataset, int K) throws IOException {
String inputfile = inputfileDir+"all/" + dataset + ".votes";
BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfile)));
System.out.printf("Spliting .votes file %s randomly All-But-%d\n", dataset, K);
// Step 1: Build votes dictionary of each user.
HashMap> user_votes = new HashMap>();
String line;
int numReviews = 0;
while ((line = reader.readLine()) != null ) {
Vote vote = parseVotesLine(line);
if (vote != null) {
if (!user_votes.containsKey(vote.user)) {
user_votes.put(vote.user, new ArrayList());
}
user_votes.get(vote.user).add(vote);
numReviews ++;
if (numReviews % 10000 == 0) System.out.print(".");
}
}
//System.out.print("\n\t #reviews: " + numReviews + ", #users: " + user_votes.size());
reader.close();
// Step 2: Write the train/valid/test file.
//System.out.print("\n 2nd Step: Writing train/validation/split files.");
reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfile)));
String outputfileTrain = inputfileDir + "train\\" + dataset + ".votes";
String outputfileValid = inputfileDir + "validation\\" + dataset + ".votes";
String outputfileTest = inputfileDir + "test\\" + dataset + ".votes";
PrintWriter writerTrain = new PrintWriter (new FileOutputStream(outputfileTrain));
PrintWriter writerValid = new PrintWriter (new FileOutputStream(outputfileValid));
PrintWriter writerTest = new PrintWriter (new FileOutputStream(outputfileTest));
int numTrain = 0, numValid = 0, numTest = 0;
for (String user : user_votes.keySet()) {
ArrayList votes = user_votes.get(user);
HashSet samples = new HashSet();
// Generate for test set and valid set first.
while (true) {
if (samples.size() >= 2*K) break;
int sample = (int) (votes.size() * Math.random());
if (!samples.contains(sample)) {
samples.add(sample);
if (samples.size() <= K) { // add to test.
writerTest.println(votes.get(sample));
numTest ++;
} else { // add to valid.
writerValid.println(votes.get(sample));
numValid ++;
}
}
}
// Add the remaining into training.
for (int i = 0; i < votes.size(); i++) {
if (!samples.contains(i)) {
writerTrain.println(votes.get(i));
numTrain ++;
}
}
}
//System.out.print("\n\t #train: " + numTrain + ", #valid: " + numValid + ", #test: " + numTest);
reader.close();
writerTrain.close();
writerValid.close();
writerTest.close();
/*System.out.print("\n Write splitted files into: \n");
System.out.println(outputfileTrain);
System.out.println(outputfileValid);
System.out.println(outputfileTest);*/
}
/**
*
* @param inputfileDir
* @param K Number of test/validation items to holdout for each user.
* @throws IOException
*/
public void SplitVotesFileByTimeAllButK(String inputfileDir, String dataset, int K) throws IOException {
String inputfile = inputfileDir + dataset + ".votes";
BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfile)));
System.out.printf("Spliting .votes file %s by time All-But-%d\n", dataset, K);
// Step 1: Build votes dictionary of each user.
HashMap> user_votes = new HashMap>();
String line;
int numReviews = 0;
while ((line = reader.readLine()) != null ) {
Vote vote = parseVotesLine(line);
if (vote != null) {
if (!user_votes.containsKey(vote.user)) {
user_votes.put(vote.user, new ArrayList());
}
user_votes.get(vote.user).add(vote);
numReviews ++;
if (numReviews % 10000 == 0) System.out.print(".");
}
}
//System.out.print("\n\t #reviews: " + numReviews + ", #users: " + user_votes.size());
reader.close();
// Step 2: Sort each user's votes.
//System.out.print("\n 2nd Step: Sort each user's votes.");
for (String user : user_votes.keySet()) {
Vote.sortByTime(user_votes.get(user));
}
// Step 3: Write the train/valid/test file.
//System.out.print("\n 3rd Step: Writing train/validation/split files.");
reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfile)));
String outputfileTrain = inputfileDir + "train/" + dataset + ".votes";
String outputfileValid = inputfileDir + "validation/" + dataset + ".votes";
String outputfileTest = inputfileDir + "test/" + dataset + ".votes";
PrintWriter writerTrain = new PrintWriter (new FileOutputStream(outputfileTrain));
PrintWriter writerValid = new PrintWriter (new FileOutputStream(outputfileValid));
PrintWriter writerTest = new PrintWriter (new FileOutputStream(outputfileTest));
int numTrain = 0, numValid = 0, numTest = 0;
for (String user : user_votes.keySet()) {
ArrayList votes = user_votes.get(user);
int trainCount = votes.size() - 2 * K;
int validCount = K;
int testCount = K;
for (int i = 0; i < votes.size(); i++) {
if (i < trainCount) {
writerTrain.println(votes.get(i));
} else if (i < trainCount + validCount) {
writerValid.println(votes.get(i));
} else {
writerTest.println(votes.get(i));
}
}
numTrain += trainCount;
numValid += validCount;
numTest += testCount;
}
System.out.print("\n\t #train: " + numTrain + ", #valid: " + numValid + ", #test: " + numTest);
reader.close();
writerTrain.close();
writerValid.close();
writerTest.close();
}
/**
* NOTE(review): the end of this Javadoc, the method signature and the first statements of
* the body appear to have been lost (text between a less-than sign and the following
* greater-than sign looks stripped). This Javadoc is therefore never terminated here, and
* everything down to the next comment terminator is treated as comment, not code. The body
* below references reader, inputfile, inputfileDir, dataset, testRatio and validRatio,
* which presumably were declared in the lost part. Restore this method from version control.
* NOTE(review): despite the wording below, validation and test are not random: after
* sorting by time, the votes after the first trainCount go to validation (validCount of
* them) and the most recent testCount go to test.
* Split the .vote Review dataset by reviewing time on each user basis.
* For each user, first select the oldest reviews as train, then randomly split valid and test.
* If a user's review number (N) is less than 10, split as for > user_votes = new HashMap>();
String line;
int numReviews = 0;
while ((line = reader.readLine()) != null ) {
Vote vote = parseVotesLine(line);
if (vote != null) {
if (!user_votes.containsKey(vote.user)) {
user_votes.put(vote.user, new ArrayList());
}
user_votes.get(vote.user).add(vote);
numReviews ++;
if (numReviews % 10000 == 0) System.out.print(".");
}
}
System.out.print("\n\t #reviews: " + numReviews + ", #users: " + user_votes.size());
reader.close();
// Step 2: Sort each user's votes.
System.out.print("\n 2nd Step: Sort each user's votes.");
for (String user : user_votes.keySet()) {
Vote.sortByTime(user_votes.get(user));
}
// Step 3: Write the train/valid/test file.
System.out.print("\n 3rd Step: Writing train/validation/split files.");
// NOTE(review): the reader is reopened here but never read again before being closed.
reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfile)));
String outputfileTrain = inputfileDir + "train/" + dataset + ".votes";
String outputfileValid = inputfileDir + "validation/" + dataset + ".votes";
String outputfileTest = inputfileDir + "test/" + dataset + ".votes";
PrintWriter writerTrain = new PrintWriter (new FileOutputStream(outputfileTrain));
PrintWriter writerValid = new PrintWriter (new FileOutputStream(outputfileValid));
PrintWriter writerTest = new PrintWriter (new FileOutputStream(outputfileTest));
int numTrain = 0, numValid = 0, numTest = 0;
for (String user : user_votes.keySet()) {
ArrayList votes = user_votes.get(user);
int trainCount, validCount, testCount;
// NOTE(review): the assignments of this branch are always overwritten, because the next
// statement is a separate `if`, not `else if`. Users with fewer than 3 votes therefore get
// trainCount = size - 2 (-1 for a single vote) with validCount = testCount = 1, so the
// printed counts do not match the files. Presumably `else if (votes.size() < 10)` was meant.
if (votes.size() < 3) {
trainCount = votes.size();
validCount = 0;
testCount = 0;
}
if (votes.size() < 10) {
trainCount = votes.size() - 2;
validCount = 1;
testCount = 1;
} else {
// Ratio-based split; (int) truncation rounds each holdout size down.
testCount = (int) (votes.size() * testRatio);
validCount = (int) (votes.size() * validRatio);
trainCount = votes.size() - testCount - validCount;
}
// Oldest trainCount votes -> train, next validCount -> validation, remainder -> test.
for (int i = 0; i < votes.size(); i++) {
if (i < trainCount) {
writerTrain.println(votes.get(i));
} else {
if (i < trainCount + validCount) writerValid.println(votes.get(i));
else writerTest.println(votes.get(i));
}
}
numTrain += trainCount;
numValid += validCount;
numTest += testCount;
}
System.out.print("\n\t #train: " + numTrain + ", #valid: " + numValid + ", #test: " + numTest);
reader.close();
writerTrain.close();
writerValid.close();
writerTest.close();
System.out.print("\n Write splitted files into: \n");
System.out.println(outputfileTrain);
System.out.println(outputfileValid);
System.out.println(outputfileTest);
}
/**
 * Only retain users whose number of reviews is in the range [min_reviews, max_reviews]
 * (both bounds inclusive).
 * <p>Reads {@code inputfileDir + dataset + ".votes"} twice (count, then filter) and writes
 * {@code inputfileDir + dataset + "_u" + min_reviews + "_" + max_reviews + ".votes"}.
 * The user id is the first space-separated token of each line.
 *
 * @param inputfileDir directory of the dataset (expected to end with a path separator)
 * @param dataset      dataset name
 * @param min_reviews  minimum number of reviews a kept user must have
 * @param max_reviews  maximum number of reviews a kept user may have
 * @throws IOException if reading or writing fails
 */
public void FilterVotesFileByUsers(String inputfileDir, String dataset, int min_reviews, int max_reviews) throws IOException {
    String inputfile = inputfileDir + dataset + ".votes";
    String outputfile = inputfileDir + dataset + "_u" + min_reviews + "_" + max_reviews + ".votes";
    reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfile)));
    System.out.printf("Filtering reviews with range [%d, %d] reviews/user for %s \n",
            min_reviews, max_reviews, dataset);

    // Step 1: count how many reviews per user.
    HashMap<String, Integer> map_user_count = new HashMap<String, Integer>();
    String line;
    try {
        while ((line = reader.readLine()) != null) {
            String user_id = line.split(" ")[0];
            Integer seen = map_user_count.get(user_id);
            map_user_count.put(user_id, seen == null ? 1 : seen + 1);
        }
    } finally {
        reader.close();
    }
    System.out.println("Before filtering, #users: " + map_user_count.size());

    // Step 2: output the new filtered file. Users outside the range are removed from the
    // map, so its final size is the number of retained users.
    reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfile)));
    PrintWriter writer = null;
    try {
        writer = new PrintWriter(new FileOutputStream(outputfile));
        while ((line = reader.readLine()) != null) {
            String user_id = line.split(" ")[0];
            Integer reviews = map_user_count.get(user_id);
            if (reviews != null && reviews >= min_reviews && reviews <= max_reviews) {
                writer.println(line);
            } else {
                map_user_count.remove(user_id);
            }
        }
    } finally {
        if (writer != null) writer.close();
        reader.close();
    }
    System.out.println("After filtering, #users: " + map_user_count.size());
    System.out.println("Write the filtered file in: " + outputfile);
}
/**
 * Filter a user if his/her number of reviews is less than the input threshold min_reviews.
 * <p>Reads {@code inputfileDir + dataset + ".votes"} twice (count, then filter) and writes
 * {@code inputfileDir + dataset + "_u" + min_reviews + ".votes"}. The user id is the first
 * space-separated token of each line.
 *
 * @param inputfileDir directory of the dataset (expected to end with a path separator)
 * @param dataset      dataset name
 * @param min_reviews  minimum number of reviews a kept user must have
 * @throws IOException if reading or writing fails
 */
public void FilterVotesFileByUsers(String inputfileDir, String dataset, int min_reviews) throws IOException {
    String inputfile = inputfileDir + dataset + ".votes";
    String outputfile = inputfileDir + dataset + "_u" + min_reviews + ".votes";
    reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfile)));
    System.out.println("Filtering " + inputfile + " with min_reviews per user: " + min_reviews);

    // Step 1: count how many reviews per user.
    HashMap<String, Integer> map_user_count = new HashMap<String, Integer>();
    String line;
    try {
        while ((line = reader.readLine()) != null) {
            String user_id = line.split(" ")[0];
            Integer seen = map_user_count.get(user_id);
            map_user_count.put(user_id, seen == null ? 1 : seen + 1);
        }
    } finally {
        reader.close();
    }
    System.out.println("Before filtering, #users: " + map_user_count.size());

    // Step 2: output the new filtered file. Dropped users are removed from the map, so its
    // final size is the number of retained users.
    reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfile)));
    PrintWriter writer = null;
    try {
        writer = new PrintWriter(new FileOutputStream(outputfile));
        while ((line = reader.readLine()) != null) {
            String user_id = line.split(" ")[0];
            Integer reviews = map_user_count.get(user_id);
            if (reviews != null && reviews >= min_reviews) {
                writer.println(line);
            } else {
                map_user_count.remove(user_id);
            }
        }
    } finally {
        if (writer != null) writer.close();
        reader.close();
    }
    System.out.println("After filtering, #users: " + map_user_count.size());
    System.out.println("Write the filtered file in: " + outputfile);
}
/**
 * Filter an item if its number of reviews is less than the input threshold min_reviews.
 * <p>Reads {@code inputfileDir + dataset + ".votes"} twice (count, then filter) and writes
 * {@code inputfileDir + dataset + "_i" + min_reviews + ".votes"}. The item id is the second
 * space-separated token of each line; lines without one are skipped.
 *
 * @param inputfileDir directory of the dataset (expected to end with a path separator)
 * @param dataset      dataset name
 * @param min_reviews  minimum number of reviews a kept item must have
 * @throws IOException if reading or writing fails
 */
public void FilterVotesFileByItems(String inputfileDir, String dataset, int min_reviews) throws IOException {
    String inputfile = inputfileDir + dataset + ".votes";
    String outputfile = inputfileDir + dataset + "_i" + min_reviews + ".votes";
    reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfile)));
    System.out.println("Filtering " + inputfile + " with min_reviews per item: " + min_reviews);

    // Step 1: count how many reviews per item.
    HashMap<String, Integer> map_item_count = new HashMap<String, Integer>();
    String line;
    try {
        while ((line = reader.readLine()) != null) {
            String[] tokens = line.split(" ");
            if (tokens.length < 2) {
                continue; // no item id on this line
            }
            Integer seen = map_item_count.get(tokens[1]);
            map_item_count.put(tokens[1], seen == null ? 1 : seen + 1);
        }
    } finally {
        reader.close();
    }
    System.out.println("Before filtering, #item: " + map_item_count.size());

    // Step 2: output the new filtered file. Dropped items are removed from the map, so its
    // final size is the number of retained items.
    reader = new BufferedReader(new InputStreamReader(new FileInputStream(inputfile)));
    PrintWriter writer = null;
    try {
        writer = new PrintWriter(new FileOutputStream(outputfile));
        while ((line = reader.readLine()) != null) {
            String[] tokens = line.split(" ");
            if (tokens.length < 2) {
                continue; // no item id on this line
            }
            String item_id = tokens[1];
            Integer reviews = map_item_count.get(item_id);
            if (reviews != null && reviews >= min_reviews) {
                writer.println(line);
            } else {
                map_item_count.remove(item_id);
            }
        }
    } finally {
        if (writer != null) writer.close();
        reader.close();
    }
    System.out.println("After filtering, #items: " + map_item_count.size());
    System.out.println("Write the filtered file in: " + outputfile);
}
/**
 * Check the user overlap of two votes datasets.
 * <p>
 * Collects the distinct user ids (first space-separated field) of
 * {@code dir + dataset1 + ".votes"} and {@code dir + dataset2 + ".votes"}. It then prints the
 * number of users present in both, and that number as a percentage of each dataset's distinct
 * users. A dataset with no users yields 0.00% instead of NaN. A dot is printed every 100000 lines
 * read as progress output.
 *
 * @param dir directory prefix, concatenated as-is with the file names
 * @param dataset1 name of the first dataset (file name without ".votes")
 * @param dataset2 name of the second dataset (file name without ".votes")
 * @throws IOException if either file cannot be read
 */
public void checkOverlapUsers(String dir, String dataset1, String dataset2) throws IOException {
    String file1 = dir + dataset1 + ".votes";
    String file2 = dir + dataset2 + ".votes";
    int userIndex = 0; // the index of user in votes file.
    HashSet<String> users1 = readVotesColumnSet(file1, userIndex);
    System.out.println("");
    HashSet<String> users2 = readVotesColumnSet(file2, userIndex);
    System.out.println("");
    HashSet<String> intersection = new HashSet<String>(users1);
    intersection.retainAll(users2);
    int overlap = intersection.size();
    double pct1 = users1.isEmpty() ? 0.0 : overlap * 100.0 / users1.size();
    double pct2 = users2.isEmpty() ? 0.0 : overlap * 100.0 / users2.size();
    System.out.printf("#overlap users of <%s, %s>: %d \t %.2f%%, %.2f%%\n",
            dataset1, dataset2, overlap, pct1, pct2);
}

/**
 * Reads the distinct values of the space-separated field at {@code column} of every line of
 * {@code file}. Lines too short to have that field are skipped. Prints a dot every 100000 lines.
 */
private static HashSet<String> readVotesColumnSet(String file, int column) throws IOException {
    HashSet<String> values = new HashSet<String>();
    try (BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(file)))) {
        String line;
        int count = 0;
        while ((line = in.readLine()) != null) {
            if (count++ % 100000 == 0) {
                System.out.print(".");
            }
            String[] fields = line.split(" ");
            if (column < fields.length) {
                values.add(fields[column]);
            }
        }
    }
    return values;
}
/**
 * Only retain top occurrence words of a review.
 * <p>
 * First builds a word dictionary of at most {@code maxWords} words from
 * {@code inputfileDir + dataset + ".votes"} via {@code buildWordsDictionary}. It then rewrites every
 * line as {@code "f0 f1 f2 f3 <wordcount> <kept words>"}, where f0..f3 are the first four
 * space-separated fields and the kept words are those fields from index 5 on that are in the
 * dictionary (each followed by a space). The output goes to
 * {@code inputfileDir + dataset + "_w" + maxWords/1000 + "k.votes"}.
 *
 * @param inputfileDir directory prefix, concatenated as-is with the file name
 * @param dataset dataset name (file name without ".votes")
 * @param maxWords The number of (top occurrence) words in the word dictionary.
 * @throws IOException if the input cannot be read or the output cannot be written
 */
public void FilterVotesReviewsByWords(String inputfileDir, String dataset, int maxWords)
        throws IOException {
    String inputfile = inputfileDir + dataset + ".votes";
    String outputfile = inputfileDir + dataset + "_w" + maxWords / 1000 + "k.votes";
    System.out.print("\nFiltering reviews by words: " + dataset);
    // Step 1: Build word dictionary (reads the input file on its own).
    HashMap<String, Integer> map_word_id = new HashMap<String, Integer>();
    this.buildWordsDictionary(inputfile, map_word_id, maxWords);
    // Step 2: Write the filtered file.
    try (BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(inputfile)));
         PrintWriter writer = new PrintWriter(new FileOutputStream(outputfile))) {
        String line;
        while ((line = in.readLine()) != null) {
            String[] arr = line.split(" ");
            // StringBuilder avoids the quadratic cost of repeated String concatenation.
            StringBuilder filtered_review_text = new StringBuilder();
            int wordcount = 0;
            for (int i = 5; i < arr.length; i++) {
                if (map_word_id.containsKey(arr[i])) {
                    wordcount++;
                    filtered_review_text.append(arr[i]).append(' ');
                }
            }
            writer.printf("%s %s %s %s %d %s\n",
                    arr[0], arr[1], arr[2], arr[3], wordcount, filtered_review_text);
        }
        // PrintWriter swallows write failures; report them instead of silently truncating output.
        if (writer.checkError()) {
            throw new IOException("Error while writing " + outputfile);
        }
    }
    System.out.println("\nWrite the filtered file in: " + outputfile);
}
/**
 * Write a matrix into file.
 * Format of each line: {@code row_id [non-zero entryCount]:\t(col1, val1)\t(col2, val2)\t...}
 * Values are printed with 4 decimals. Rows 1 .. matrix.length()[0] - 1 are written; row 0 is skipped.
 *
 * @param matrix matrix to dump
 * @param filename path of the output file (overwritten)
 * @throws IOException if the file cannot be opened or written
 */
public static void writeMatrixToFile(SparseMatrix matrix, String filename) throws IOException {
    try (PrintWriter writer = new PrintWriter(new FileOutputStream(filename))) {
        int rowCount = matrix.length()[0];
        for (int i = 1; i < rowCount; i++) {
            ArrayList<Integer> indexList = matrix.getRowRef(i).indexList();
            // An empty row naturally renders as "i [0]:\t", so no special case is needed.
            StringBuilder line = new StringBuilder(String.format("%d [%d]:\t", i, indexList.size()));
            for (int j : indexList) {
                line.append(String.format("(%d, %.4f)\t", j, matrix.getValue(i, j)));
            }
            writer.println(line);
        }
        // PrintWriter swallows write failures; report them instead of silently truncating output.
        if (writer.checkError()) {
            throw new IOException("Error while writing " + filename);
        }
    }
}
/**
* Process the .lexicon file (generate by thuir-sentires.rar tool), and generate feature set.
* Select top features by descending order of number of opinions.
*
* @param lexiconFile
* @param aspectRatio Percentage of top aspects to read.
* @return
* @throws IOException
*/
static public HashMap> loadFeaturesFromLexiconFile(String lexiconFile,
double aspectRatio) throws IOException {
HashMap> map_feature_opinion =
new HashMap>();
// System.out.println("Loading features from lexicon file: " + lexiconFile);
BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(lexiconFile)));
String line;
while ((line = reader.readLine()) != null) {
String feature_opinion = line.split("\t")[1];
String feature = feature_opinion.split("\\|")[0].replaceAll("!", "").trim();
String opinion = feature_opinion.split("\\|")[1].trim();
if (!map_feature_opinion.containsKey(feature)) {
map_feature_opinion.put(feature, new HashSet());
}
map_feature_opinion.get(feature).add(opinion);
}
System.out.println("Feature count in total: " + map_feature_opinion.size());
// Select features by descending order of number of opinions.
int aspectNum = (int) (map_feature_opinion.size() * aspectRatio);
HashMap map_feature_count = new HashMap();
for (Map.Entry> entry : map_feature_opinion.entrySet()) {
map_feature_count.put(entry.getKey(), entry.getValue().size());
}
HashSet topFeatures = new HashSet(
CommonUtils.TopKeysByValue(map_feature_count, aspectNum, null));
Set featureSet = new HashSet(map_feature_opinion.keySet());
for (String feature : featureSet) {
if (!topFeatures.contains(feature)) {
map_feature_opinion.remove(feature);
}
}
reader.close();
// System.out.println("# of features loaded: " + map_feature_opinion.size());
return map_feature_opinion;
}
/**
 * Load feature names from a tab-separated feature file.
 * Each line is trimmed and split on tabs. The first column of every line with at least two
 * columns is taken as a feature; other lines (blank, single column) are ignored.
 *
 * @param featureFile path of the feature file
 * @return set of feature names
 * @throws IOException if the file cannot be opened or read
 */
static public HashSet<String> loadFeaturesFromFeatureFile(String featureFile)
        throws IOException {
    HashSet<String> features = new HashSet<String>();
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(featureFile)))) {
        String line;
        while ((line = reader.readLine()) != null) {
            String[] arr = line.trim().split("\t");
            if (arr.length > 1) {
                features.add(arr[0]);
            }
        }
    }
    return features;
}
/**
* Load positive Features.
* @param lexiconFile
* @return
* @throws IOException
*/
static public HashMap> loadPosFeaturesFromLexiconFile(String lexiconFile,
double aspectRatio) throws IOException {
HashMap> map_feature_opinion =
new HashMap>();
// System.out.println("Loading features from lexicon file: " + lexiconFile);
BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(lexiconFile)));
String line;
while ((line = reader.readLine()) != null) {
if (!line.contains("[1]")) continue; // Only consider positive FO pairs.
String feature_opinion = line.split("\t")[1];
String feature = feature_opinion.split("\\|")[0].replaceAll("!", "").trim();
String opinion = feature_opinion.split("\\|")[1].trim();
if (!map_feature_opinion.containsKey(feature)) {
map_feature_opinion.put(feature, new HashSet());
}
map_feature_opinion.get(feature).add(opinion);
}
System.out.printf("Feature count in total: %d. ", map_feature_opinion.size());
// Select features by descending order of number of opinions.
int aspectNum = (int) (map_feature_opinion.size() * aspectRatio);
HashMap map_feature_count = new HashMap();
for (Map.Entry> entry : map_feature_opinion.entrySet()) {
map_feature_count.put(entry.getKey(), entry.getValue().size());
}
HashSet topFeatures = new HashSet(
CommonUtils.TopKeysByValue(map_feature_count, aspectNum, null));
Set featureSet = new HashSet(map_feature_opinion.keySet());
for (String feature : featureSet) {
if (!topFeatures.contains(feature)) {
map_feature_opinion.remove(feature);
}
}
// Count number of F-O pairs.
int count = 0;
for (String feature : map_feature_opinion.keySet()) {
count += map_feature_opinion.get(feature).size();
}
System.out.printf("Positive F-O pairs: %d. ", count);
reader.close();
return map_feature_opinion;
}
/**
* Filter FO pairs that do not occur in the training votes file.
*/
static public HashMap> filterFOpairs(
HashMap> feature_opinions, String votesFile) throws IOException {
HashSet features = new HashSet(feature_opinions.keySet());
HashMap> filteredFO = new HashMap>();
for (String feature : feature_opinions.keySet()) {
filteredFO.put(feature, new ArrayList());
}
BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(votesFile)));
String line;
while ((line = reader.readLine()) != null) {
Vote vote = parseVotesLine(line);
// Find all features occurred in the review.
HashSet find_features = findFeaturesFromReview(features, vote.review);
for (String feature : find_features) {
// Find all opinions occurred in the review.
HashSet opinions = new HashSet(feature_opinions.get(feature));
opinions = findFeaturesFromReview(opinions, vote.review);
for (String opinion : opinions) {
filteredFO.get(feature).add(opinion);
feature_opinions.get(feature).remove(opinion);
}
}
}
reader.close();
// Count number of F-O pairs.
int count = 0;
for (String feature : filteredFO.keySet()) {
count += filteredFO.get(feature).size();
}
System.out.println("Filtered F-O pairs: " + count);
return filteredFO;
}
/*==============================================================================================
* Private and protected functions.
*==============================================================================================*/
/** Build itemWordsMatrix and userWordsMatrix based on the input user, item and word dictionary.
 * <p>
 * For every line of {@code fileName} with at least 4 space-separated fields, field 0 is looked up
 * in {@code map_user_id} and field 1 in {@code map_item_id}. Every field from index 5 on that
 * (after trimming) is in {@code map_word_id} increments both the (user, word) cell of
 * {@code userWordsMatrix} and the (item, word) cell of {@code itemWordsMatrix} by 1.
 *
 * @param fileName votes file to read
 * @param itemWordsMatrix matrix updated in place with item-word occurrence counts
 * @param userWordsMatrix matrix updated in place with user-word occurrence counts
 * @param map_item_id Dictionary of all items (id starts from 1)
 * @param map_user_id Dictionary of all users (id starts from 1)
 * @param map_word_id Dictionary of all words (id starts from 1)
 * @throws IOException if the file cannot be read
 * @throws NullPointerException if a line's user or item is missing from its dictionary
 */
public void buildWordsMatrix(String fileName, SparseMatrix itemWordsMatrix,
        SparseMatrix userWordsMatrix, HashMap<String, Integer> map_item_id,
        HashMap<String, Integer> map_user_id, HashMap<String, Integer> map_word_id) throws IOException {
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(fileName)))) {
        String line;
        while ((line = reader.readLine()) != null) {
            String[] arr = line.split(" ");
            if (arr.length < 4) {
                continue; // too short to hold user, item and review fields
            }
            // Extract item, user and review words.
            int userID = map_user_id.get(arr[0]);
            int itemID = map_item_id.get(arr[1]);
            for (int i = 5; i < arr.length; i++) {
                Integer wordID = map_word_id.get(arr[i].trim());
                if (wordID == null) {
                    continue; // word not in the dictionary
                }
                userWordsMatrix.setValue(userID, wordID, userWordsMatrix.getValue(userID, wordID) + 1);
                itemWordsMatrix.setValue(itemID, wordID, itemWordsMatrix.getValue(itemID, wordID) + 1);
            }
        }
    }
}
/**
* Build words dictionary (from votes file).
* Only consider English reviews.
*
* @param fileName
* @param map_word_id Save the results of word dictionary.
* @param maxWords The maximum words in the dictionary (select top words). To disable the function, set it as 0.
* @throws IOException
* @throws LangDetectException
*/
public void buildWordsDictionary(String fileName, HashMap map_word_id,
int maxWords) throws IOException {
BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(fileName)));
// map from word to its number of occurrence.
HashMap map_word_count = new HashMap();
StopwordsFilter.init("lib/stopwords.txt");
String line;
int linecount=0;
while ((line = reader.readLine()) != null) {
Vote vote = parseVotesLine(line);
if (vote!=null) {
// Process review words.
String review_text = vote.review.trim();
// if (!LanguageDetector.isEnglish(review_text)) continue; // Filter nonEnglish reviews.
for (String word : review_text.split(" ")) {
if (StopwordsFilter.isStopword(word)) continue; // Filter stopwords.
if (word.matches(".*\\d+.*")) continue;// Filter word that contains digit.
if (!map_word_count.containsKey(word)) {
map_word_count.put(word, 0);
}
map_word_count.put(word, map_word_count.get(word) + 1);
}
}
if (linecount % 10000 == 0) {
System.out.print(".");
}
linecount++;
}
// System.out.print("\nBefore filtering, dictionary_size: " + map_word_count.size() +", after filtering: " + maxWords);
// Use the most frequent maxWords as the word dictionary.
List> sortedMap;
if (maxWords > 0) {
sortedMap = mostFrequentEntries(map_word_count, maxWords);
} else {
sortedMap = CommonUtils.SortMapByValue(map_word_count);
}
// Words are sorted by its number of occurrence.
for (Map.Entry entity : sortedMap) {
map_word_id.put(entity.getKey(), map_word_id.size());
}
reader.close();
}
/**
* Build aspects matrix of the input .votes file, given the user/item/aspect dictionary.
* In this function, an aspect represents a F-O pair.
* @param votesFile
* @param map_user_id
* @param map_item_id
* @param map_aspect_id
* @param itemAspect
* @param userAspect
* @throws IOException
*/
static public void buildAspectsMatrix_FO(String votesFile, HashMap map_user_id,
HashMap