Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.impl.mx.SparseMx;

/**
 * Strategy for (re)training topic-classifier weights on a labeled batch.
 * Implementations take the previous weight matrix as a starting point / anchor
 * and return a new sparse weight matrix of the same shape.
 */
public interface Optimizer {
// trainingSet: documents as rows; correctTopics: label per row; prevWeights: prior weights.
SparseMx optimizeWeights(Mx trainingSet, String[] correctTopics, SparseMx prevWeights);
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.MxTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.RowsVecArrayMx;
import com.expleague.commons.math.vectors.impl.mx.SparseMx;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.impl.vectors.SparseVec;
import com.google.common.annotations.VisibleForTesting;
import gnu.trove.map.TObjectIntMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedReader;
import java.io.File;
Expand All @@ -24,6 +25,7 @@
import java.util.stream.StreamSupport;

public class SklearnSgdPredictor implements TopicsPredictor {
private static final Logger LOGGER = LoggerFactory.getLogger(SklearnSgdPredictor.class.getName());
private static final Pattern PATTERN = Pattern.compile("\\b\\w\\w+\\b", Pattern.UNICODE_CHARACTER_CLASS);

private final String weightsPath;
Expand All @@ -32,43 +34,45 @@ public class SklearnSgdPredictor implements TopicsPredictor {
//lazy loading
private TObjectIntMap<String> countVectorizer;
private Vec intercept;
private Mx weights;
private SparseMx weights;
private String[] topics;

public SklearnSgdPredictor(String cntVectorizerPath, String weightsPath) {
this.weightsPath = weightsPath;
this.cntVectorizerPath = cntVectorizerPath;
}

/**
 * Converts a document's TF-IDF feature map into a sparse vector in the
 * count-vectorizer's feature space.
 *
 * @param tfIdf token -> TF-IDF weight for a single document
 * @return sparse vector of dimension {@code countVectorizer.size()}
 */
private SparseVec vectorize(Map<String, Double> tfIdf) {
final int[] indices = new int[tfIdf.size()];
final double[] values = new double[tfIdf.size()];

// Iterate entries once instead of keySet() + get() per key (avoids a second lookup).
int ind = 0;
for (Map.Entry<String, Double> entry : tfIdf.entrySet()) {
// NOTE(review): TObjectIntMap.get returns its no-entry value (0 unless configured)
// for tokens absent from the vocabulary, which would collide with feature index 0 —
// confirm the vocabulary load sets a safe no-entry value or inputs are pre-filtered.
indices[ind] = countVectorizer.get(entry.getKey());
values[ind] = entry.getValue();
ind++;
}

return new SparseVec(countVectorizer.size(), indices, values);
}

@Override
public Topic[] predict(Document document) {
loadMeta();
loadVocabulary();

final Map<String, Double> tfIdf = document.tfIdf();
final int[] indices = new int[tfIdf.size()];
final double[] values = new double[tfIdf.size()];
{ //convert TF-IDF features to sparse vector
int ind = 0;
for (String key : tfIdf.keySet()) {
final int valueIndex = countVectorizer.get(key);
indices[ind] = valueIndex;
values[ind] = tfIdf.get(key);
ind++;
}
}

final Vec probabilities;
{ // compute topic probabilities
final SparseVec vectorized = new SparseVec(countVectorizer.size(), indices, values);
final SparseVec vectorized = vectorize(document.tfIdf());
final Vec score = MxTools.multiply(weights, vectorized);
final Vec sum = VecTools.sum(score, intercept);
final Vec scaled = VecTools.scale(sum, -1);
VecTools.exp(scaled);

final double[] ones = new double[score.dim()];
Arrays.fill(ones, 1);
final Vec vecOnes = new ArrayVec(ones, 0, ones.length);
final Vec vecOnes = new ArrayVec(score.dim());
VecTools.fill(vecOnes, 1);

probabilities = VecTools.sum(scaled, vecOnes);
for (int i = 0; i < probabilities.dim(); i++) {
double changed = 1 / probabilities.get(i);
Expand All @@ -87,6 +91,19 @@ public Topic[] predict(Document document) {
return result;
}

/**
 * Replaces the model's weight matrix (e.g. after an optimizer pass).
 * NOTE(review): stores the caller's matrix without a defensive copy — later
 * external mutation would change predictions; confirm callers hand over ownership.
 */
@Override
public void updateWeights(SparseMx weights) {
this.weights = weights;
}

/**
 * Returns the current weight matrix as a live reference (not a copy).
 * The field is lazily loaded, so this may be null before loadMeta()/updateWeights().
 */
public SparseMx getWeights() {
return weights;
}

/**
 * Returns the topic labels as a live array reference (not a copy).
 * Lazily loaded, so this may be null before loadMeta() has run.
 */
public String[] getTopics() {
return topics;
}

public void init() {
loadMeta();
loadVocabulary();
Expand All @@ -107,7 +124,7 @@ private void loadMeta() {
topics[i] = br.readLine();
}

final Vec[] coef = new Vec[classes];
final SparseVec[] coef = new SparseVec[classes];
String line;
for (int index = 0; index < classes; index++) {
line = br.readLine();
Expand All @@ -123,11 +140,10 @@ private void loadMeta() {
values[i / 2] = value;
}

final SparseVec sparseVec = new SparseVec(currentFeatures, indeces, values);
coef[index] = sparseVec;
coef[index] = new SparseVec(currentFeatures, indeces, values);
}

weights = new RowsVecArrayMx(coef);
weights = new SparseMx(coef);
MxTools.transpose(weights);

line = br.readLine();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.MxTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.SparseMx;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.impl.vectors.SparseVec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;

/**
 * Gradient-based softmax (multinomial logistic) regression optimizer with L1
 * sparsity and L2 proximity (to the previous weights) regularization.
 * Per-document and per-topic work is parallelized on a fixed worker pool.
 */
public class SoftmaxRegressionOptimizer implements Optimizer {
  private static final Logger LOGGER = LoggerFactory.getLogger(SoftmaxRegressionOptimizer.class);

  private final List<String> topicList;
  // Daemon worker threads: this field-held pool is never shut down explicitly,
  // and non-daemon threads would keep the JVM alive after the optimizer is done.
  private final ExecutorService executor = Executors.newFixedThreadPool(8, runnable -> {
    final Thread thread = new Thread(runnable);
    thread.setDaemon(true);
    return thread;
  });

  public SoftmaxRegressionOptimizer(String[] topics) {
    topicList = Arrays.asList(topics);
  }

  /** Subgradient of the L1 penalty: sign of each weight entry. */
  private Mx l1Gradient(SparseMx weights) {
    final Mx gradient = new SparseMx(weights.rows(), weights.columns());
    for (int row = 0; row < weights.rows(); row++) {
      for (int col = 0; col < weights.columns(); col++) {
        gradient.set(row, col, Math.signum(weights.get(row, col)));
      }
    }
    return gradient;
  }

  /** Gradient of the proximity penalty ||W - prevW||^2: 2 * (W - prevW). */
  private Mx l2Gradient(SparseMx weights, SparseMx prevWeights) {
    final Mx gradient = new SparseMx(weights.rows(), weights.columns());
    for (int row = 0; row < weights.rows(); row++) {
      for (int col = 0; col < weights.columns(); col++) {
        gradient.set(row, col, 2 * (weights.get(row, col) - prevWeights.get(row, col)));
      }
    }
    return gradient;
  }

  /**
   * Computes, in parallel over documents, the softmax probability assigned to
   * each document's correct topic.
   *
   * @param weights       current weight matrix (one row per topic)
   * @param trainingSet   documents as rows
   * @param correctTopics correct topic index per document
   * @return vector with P(correct topic | document) per training row
   */
  private Vec computeSoftmaxValues(SparseMx weights, Mx trainingSet, int[] correctTopics) {
    final double[] softmaxValues = new double[trainingSet.rows()];

    final CountDownLatch latch = new CountDownLatch(trainingSet.rows());
    for (int i = 0; i < trainingSet.rows(); i++) {
      final int doc = i;
      executor.execute(() -> {
        final Vec x = VecTools.copySparse(trainingSet.row(doc));
        final Vec scores = MxTools.multiply(weights, x);
        // NOTE(review): exp without a max-shift; large scores can overflow to
        // infinity — confirm feature/weight scales keep this numerically safe.
        VecTools.exp(scores);

        final double numerator = scores.get(correctTopics[doc]);
        double denominator = 0.0;
        for (int k = 0; k < weights.rows(); k++) {
          denominator += scores.get(k);
        }

        // Each task writes a distinct index, so the plain array write is race-free;
        // the latch publishes all writes to the awaiting thread.
        softmaxValues[doc] = numerator / denominator;
        latch.countDown();
      });
    }

    await(latch);
    return new ArrayVec(softmaxValues, 0, softmaxValues.length);
  }

  /**
   * Computes the softmax gradient row-by-row (one task per topic) plus the
   * summed softmax value used as the convergence measure.
   */
  private SoftmaxData softmaxGradient(SparseMx weights, Mx trainingSet, int[] correctTopics) {
    final SparseVec[] gradients = new SparseVec[weights.rows()];
    final Vec softmaxValues = computeSoftmaxValues(weights, trainingSet, correctTopics);

    final CountDownLatch latch = new CountDownLatch(weights.rows());
    for (int j = 0; j < weights.rows(); j++) {
      final int topic = j;
      executor.execute(() -> {
        // Per-document multiplier: 1{correct topic == this topic} - P(correct topic | doc).
        final SparseVec scales = new SparseVec(trainingSet.rows());
        for (int doc = 0; doc < trainingSet.rows(); doc++) {
          final int indicator = correctTopics[doc] == topic ? 1 : 0;
          scales.set(doc, indicator - softmaxValues.get(doc));
        }

        SparseVec grad = new SparseVec(weights.columns());
        for (int doc = 0; doc < trainingSet.rows(); doc++) {
          final Vec x = VecTools.copySparse(trainingSet.row(doc));
          // NOTE(review): x (feature-dim) is scaled by the whole scales vector
          // (document-dim); the intent looks like scaling by scales.get(doc) —
          // verify VecTools.scale(Vec, Vec) semantics before relying on this.
          VecTools.scale(x, scales);
          grad = VecTools.sum(grad, x);
        }

        gradients[topic] = grad;
        latch.countDown();
      });
    }

    await(latch);
    return new SoftmaxData(VecTools.sum(softmaxValues), new SparseMx(gradients));
  }

  /** Blocks on the latch; fails fast instead of proceeding with partial results. */
  private static void await(CountDownLatch latch) {
    try {
      latch.await();
    } catch (InterruptedException e) {
      // Re-assert the interrupt flag and abort: continuing here would silently
      // use partially computed softmax values/gradients.
      Thread.currentThread().interrupt();
      throw new IllegalStateException("Interrupted while waiting for worker tasks", e);
    }
  }

  /**
   * Runs up to 100 gradient iterations from a zero weight matrix, stopping when
   * the summed softmax value changes by less than 1e-3 between iterations.
   *
   * @param trainingSet   documents as rows
   * @param correctTopics correct topic name per document (looked up in the topic list)
   * @param prevWeights   previous weight matrix, used as the L2 proximity anchor
   * @return optimized weight matrix shaped like {@code prevWeights}
   */
  @Override
  public SparseMx optimizeWeights(Mx trainingSet, String[] correctTopics, SparseMx prevWeights) {
    final double alpha = 1e-1;    // learning rate
    final double lambda1 = 1e-2;  // L1 strength
    final double lambda2 = 1e-1;  // L2 proximity strength
    final int maxIter = 100;      // iteration count is integral (was declared double)
    final int[] topicIndices = Stream.of(correctTopics).mapToInt(topicList::indexOf).toArray();

    double previousValue = 0;
    SparseMx weights = new SparseMx(prevWeights.rows(), prevWeights.columns());
    for (int iteration = 1; iteration <= maxIter; iteration++) {
      LOGGER.info("Iteration {}", iteration);
      final SoftmaxData data = softmaxGradient(weights, trainingSet, topicIndices);
      LOGGER.info("Softmax value : {}", data.value);
      if (Math.abs(data.value - previousValue) < 1e-3) {
        break;
      }
      previousValue = data.value;

      final Mx l1 = l1Gradient(weights);
      final Mx l2 = l2Gradient(weights, prevWeights);

      // NOTE(review): data.gradients is the log-likelihood ascent direction
      // ((indicator - p) * x), yet it is subtracted here, and both regularizer
      // gradients are subtracted inside the parentheses — verify these signs
      // against the intended objective before trusting convergence.
      for (int row = 0; row < weights.rows(); row++) {
        for (int col = 0; col < weights.columns(); col++) {
          final double updated = weights.get(row, col)
              - alpha * (data.gradients.get(row, col) - lambda1 * l1.get(row, col) - lambda2 * l2.get(row, col));
          weights.set(row, col, updated);
        }
      }
    }

    return weights;
  }

  /** Immutable pair of convergence value and gradient matrix (static: no enclosing-instance ref needed). */
  private static class SoftmaxData {
    private final double value;
    private final SparseMx gradients;

    SoftmaxData(double value, SparseMx gradients) {
      this.value = value;
      this.gradients = gradients;
    }
  }
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier;

import com.expleague.commons.math.vectors.impl.mx.SparseMx;

/**
 * Predicts topic assignments for documents; the weight matrix can be swapped
 * in at runtime (e.g. after a training pass).
 */
public interface TopicsPredictor {
// Optional eager-initialization hook; no-op by default.
default void init() {
}

// Returns topic predictions for the given document.
Topic[] predict(Document document);
// Replaces the predictor's current weight matrix.
void updateWeights(SparseMx weights);
}
4 changes: 2 additions & 2 deletions examples/src/main/resources/classifier_weights
Git LFS file not shown
4 changes: 2 additions & 2 deletions examples/src/main/resources/cnt_vectorizer
Git LFS file not shown
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ public Stream<Topic[]> apply(Document document) {
}
}

private static double[] parseDoubles(String line) {
public static double[] parseDoubles(String line) {
return Arrays
.stream(line.split(" "))
.mapToDouble(Double::parseDouble)
Expand Down
Loading