package org.apache.mahout.cf.taste.hadoop.als;

import java.io.File;
import java.util.Iterator;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.impl.TasteTestCase;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.hadoop.MathHelper;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.class */
/**
 * End-to-end tests that run {@link ParallelALSFactorizationJob} (and, in one case,
 * {@link RecommenderJob}) on a tiny 4x4 toy data set and check that the product of the
 * resulting user and item feature matrices approximates the known preferences.
 */
public class ParallelALSFactorizationJobTest extends TasteTestCase {

    private static final Logger log = LoggerFactory.getLogger(ParallelALSFactorizationJobTest.class);

    /** Value passed as {@code --lambda} to every factorization run. */
    private static final double LAMBDA = 0.065d;

    /** Value passed as {@code --alpha} in the implicit-feedback runs; also used to weight the error there. */
    private static final double ALPHA = 20.0d;

    /** Value passed as {@code --numFeatures}; also the column count of the matrices read back. */
    private static final int NUM_FEATURES = 3;

    /** Value passed as {@code --numIterations}. */
    private static final int NUM_ITERATIONS = 5;

    /** Toy preferences in {@code userID,itemID,rating} form, using IDs outside the int range. */
    private static final String[] LONG_ID_PREFERENCES = {
        "5568227754922264005,-4758971626494767444,5.0",
        "5568227754922264005,3688396615879561990,5.0",
        "5568227754922264005,4594226737871995304,2.0",
        "550945997885173934,-4758971626494767444,2.0",
        "550945997885173934,4594226737871995304,3.0",
        "550945997885173934,706816485922781596,5.0",
        "2448095297482319463,3688396615879561990,5.0",
        "2448095297482319463,706816485922781596,3.0",
        "6839920411763636962,-4758971626494767444,3.0",
        "6839920411763636962,706816485922781596,5.0"
    };

    private File inputFile;
    private File intermediateDir;
    private File outputDir;
    private File tmpDir;
    private Configuration conf;

    /**
     * Prepares the input file, the working directories and the configuration for a single test.
     *
     * @throws Exception if the superclass set-up or the temp-file helpers fail
     */
    @Override // org.apache.mahout.common.MahoutTestCase
    @Before
    public void setUp() throws Exception {
        super.setUp();
        inputFile = getTestTempFile("prefs.txt");
        // NOTE(review): the directories are deleted right after creation, presumably because
        // the jobs insist on creating their own output directories - confirm.
        intermediateDir = getTestTempDir("intermediate");
        intermediateDir.delete();
        outputDir = getTestTempDir("output");
        outputDir.delete();
        tmpDir = getTestTempDir("tmp");
        conf = getConfiguration();
        // NOTE(review): presumably clears static state kept by SharingMapper between runs - confirm.
        SharingMapper.reset();
    }

    /** Explicit-feedback factorization with a single solver thread. */
    @Test
    public void completeJobToyExample() throws Exception {
        explicitExample(1);
    }

    /** Explicit-feedback factorization with two solver threads. */
    @Test
    public void completeJobToyExampleMultithreaded() throws Exception {
        explicitExample(2);
    }

    /**
     * Factorizes a 4x4 explicit-rating matrix (NaN marks an unknown rating) and asserts that the
     * RMSE over the known ratings is below 0.2.
     *
     * @param numThreads value passed as {@code --numThreadsPerSolver}
     * @throws Exception if writing the input, running the job or reading its output fails
     */
    private void explicitExample(int numThreads) throws Exception {
        double na = Double.NaN;
        Matrix preferences = new SparseRowMatrix(4, 4, new Vector[] {
            new DenseVector(new double[] {5.0d, 5.0d, 2.0d, na}),
            new DenseVector(new double[] {2.0d, na, 3.0d, 5.0d}),
            new DenseVector(new double[] {na, 5.0d, na, 3.0d}),
            new DenseVector(new double[] {3.0d, na, na, 5.0d})
        });
        writeLines(inputFile, preferencesAsText(preferences));

        ParallelALSFactorizationJob alsFactorization = new ParallelALSFactorizationJob();
        alsFactorization.setConf(conf);
        // Fail here rather than with a confusing missing-file error when the output is read back.
        assertEquals(0, alsFactorization.run(new String[] {
            "--input", inputFile.getAbsolutePath(),
            "--output", outputDir.getAbsolutePath(),
            "--tempDir", tmpDir.getAbsolutePath(),
            "--lambda", String.valueOf(LAMBDA),
            "--numFeatures", String.valueOf(NUM_FEATURES),
            "--numIterations", String.valueOf(NUM_ITERATIONS),
            "--numThreadsPerSolver", String.valueOf(numThreads)
        }));

        Matrix u = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "U/part-m-00000"),
            preferences.numRows(), NUM_FEATURES);
        Matrix m = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "M/part-m-00000"),
            preferences.numCols(), NUM_FEATURES);

        StringBuilder info = new StringBuilder();
        info.append("\nA - users x items\n\n");
        info.append(MathHelper.nice(preferences));
        info.append("\nU - users x features\n\n");
        info.append(MathHelper.nice(u));
        info.append("\nM - items x features\n\n");
        info.append(MathHelper.nice(m));
        Matrix ak = u.times(m.transpose());
        info.append("\nAk - users x items\n\n");
        info.append(MathHelper.nice(ak));
        info.append('\n');
        log.info(info.toString());

        FullRunningAverage avg = new FullRunningAverage();
        for (MatrixSlice slice : preferences) {
            for (Vector.Element e : slice.nonZeroes()) {
                if (Double.isNaN(e.get())) {
                    continue; // unknown rating, nothing to compare against
                }
                double pref = e.get();
                double estimate = u.viewRow(slice.index()).dot(m.viewRow(e.index()));
                double err = pref - estimate;
                avg.addDatum(err * err);
                log.info("Comparing preference of user [{}] towards item [{}], was [{}] estimate is [{}]",
                    new Object[] {slice.index(), e.index(), pref, estimate});
            }
        }
        double rmse = Math.sqrt(avg.getAverage());
        log.info("RMSE: {}", rmse);
        assertTrue(rmse < 0.2d);
    }

    /** Implicit-feedback factorization with a single solver thread. */
    @Test
    public void completeJobImplicitToyExample() throws Exception {
        implicitExample(1);
    }

    /** Implicit-feedback factorization with two solver threads. */
    @Test
    public void completeJobImplicitToyExampleMultithreaded() throws Exception {
        implicitExample(2);
    }

    /**
     * Factorizes a 4x4 matrix of observations with {@code --implicitFeedback true} and asserts
     * that the confidence-weighted RMSE against the 0/1 preference matrix is below 0.4.
     *
     * @param numThreads value passed as {@code --numThreadsPerSolver}
     * @throws Exception if writing the input, running the job or reading its output fails
     */
    public void implicitExample(int numThreads) throws Exception {
        Matrix observations = new SparseRowMatrix(4, 4, new Vector[] {
            new DenseVector(new double[] {5.0d, 5.0d, 2.0d, 0.0d}),
            new DenseVector(new double[] {2.0d, 0.0d, 3.0d, 5.0d}),
            new DenseVector(new double[] {0.0d, 5.0d, 0.0d, 3.0d}),
            new DenseVector(new double[] {3.0d, 0.0d, 0.0d, 5.0d})
        });
        // 1 wherever an observation exists, 0 elsewhere.
        Matrix preferences = new SparseRowMatrix(4, 4, new Vector[] {
            new DenseVector(new double[] {1.0d, 1.0d, 1.0d, 0.0d}),
            new DenseVector(new double[] {1.0d, 0.0d, 1.0d, 1.0d}),
            new DenseVector(new double[] {0.0d, 1.0d, 0.0d, 1.0d}),
            new DenseVector(new double[] {1.0d, 0.0d, 0.0d, 1.0d})
        });
        writeLines(inputFile, preferencesAsText(observations));

        ParallelALSFactorizationJob alsFactorization = new ParallelALSFactorizationJob();
        alsFactorization.setConf(conf);
        // Fail here rather than with a confusing missing-file error when the output is read back.
        assertEquals(0, alsFactorization.run(new String[] {
            "--input", inputFile.getAbsolutePath(),
            "--output", outputDir.getAbsolutePath(),
            "--tempDir", tmpDir.getAbsolutePath(),
            "--lambda", String.valueOf(LAMBDA),
            "--implicitFeedback", String.valueOf(true),
            "--alpha", String.valueOf(ALPHA),
            "--numFeatures", String.valueOf(NUM_FEATURES),
            "--numIterations", String.valueOf(NUM_ITERATIONS),
            "--numThreadsPerSolver", String.valueOf(numThreads)
        }));

        Matrix u = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "U/part-m-00000"),
            observations.numRows(), NUM_FEATURES);
        Matrix m = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "M/part-m-00000"),
            observations.numCols(), NUM_FEATURES);

        StringBuilder info = new StringBuilder();
        info.append("\nObservations - users x items\n");
        info.append(MathHelper.nice(observations));
        info.append("\nA - users x items\n\n");
        info.append(MathHelper.nice(preferences));
        info.append("\nU - users x features\n\n");
        info.append(MathHelper.nice(u));
        info.append("\nM - items x features\n\n");
        info.append(MathHelper.nice(m));
        Matrix ak = u.times(m.transpose());
        info.append("\nAk - users x items\n\n");
        info.append(MathHelper.nice(ak));
        info.append('\n');
        log.info(info.toString());

        FullRunningAverage avg = new FullRunningAverage();
        for (MatrixSlice slice : preferences) {
            for (Vector.Element e : slice.nonZeroes()) {
                if (Double.isNaN(e.get())) {
                    continue;
                }
                double pref = e.get();
                double estimate = u.viewRow(slice.index()).dot(m.viewRow(e.index()));
                // Squared error is weighted by 1 + alpha * observation for this cell.
                double confidence = 1.0d + (ALPHA * observations.getQuick(slice.index(), e.index()));
                double err = pref - estimate;
                avg.addDatum(confidence * err * err);
                log.info("Comparing preference of user [{}] towards item [{}], was [{}] with confidence [{}] "
                    + "estimate is [{}]",
                    new Object[] {slice.index(), e.index(), pref, confidence, estimate});
            }
        }
        double rmse = Math.sqrt(avg.getAverage());
        log.info("RMSE: {}", rmse);
        assertTrue(rmse < 0.4d);
    }

    /**
     * Runs the factorization with {@code --usesLongIDs true}, checks that the user and item ID
     * index files and both feature matrices contain four entries each, and asserts that the RMSE
     * over the input preferences is below 0.2.
     *
     * @throws Exception if writing the input, running the job or reading its output fails
     */
    @Test
    public void exampleWithIDMapping() throws Exception {
        writeLines(inputFile, LONG_ID_PREFERENCES);

        ParallelALSFactorizationJob alsFactorization = new ParallelALSFactorizationJob();
        alsFactorization.setConf(conf);
        // Fail here rather than with a confusing missing-file error when the output is read back.
        assertEquals(0, alsFactorization.run(new String[] {
            "--input", inputFile.getAbsolutePath(),
            "--output", outputDir.getAbsolutePath(),
            "--tempDir", tmpDir.getAbsolutePath(),
            "--lambda", String.valueOf(LAMBDA),
            "--numFeatures", String.valueOf(NUM_FEATURES),
            "--numIterations", String.valueOf(NUM_ITERATIONS),
            "--numThreadsPerSolver", String.valueOf(1),
            "--usesLongIDs", String.valueOf(true)
        }));

        String outputPath = outputDir.getAbsolutePath();
        assertEquals(4, TasteHadoopUtils.readIDIndexMap(outputPath + "/userIDIndex/part-r-00000", conf).size());
        assertEquals(4, TasteHadoopUtils.readIDIndexMap(outputPath + "/itemIDIndex/part-r-00000", conf).size());

        OpenIntObjectHashMap<Vector> u = MathHelper.readMatrixRows(conf, new Path(outputPath, "U/part-m-00000"));
        OpenIntObjectHashMap<Vector> m = MathHelper.readMatrixRows(conf, new Path(outputPath, "M/part-m-00000"));
        assertEquals(4, u.size());
        assertEquals(4, m.size());

        FullRunningAverage avg = new FullRunningAverage();
        for (String line : LONG_ID_PREFERENCES) {
            String[] tokens = TasteHadoopUtils.splitPrefTokens(line);
            // The feature matrices are keyed by the int index derived from each long ID.
            int userIndex = TasteHadoopUtils.idToIndex(Long.parseLong(tokens[0]));
            int itemIndex = TasteHadoopUtils.idToIndex(Long.parseLong(tokens[1]));
            double err = Double.parseDouble(tokens[2]) - u.get(userIndex).dot(m.get(itemIndex));
            avg.addDatum(err * err);
        }
        double rmse = Math.sqrt(avg.getAverage());
        log.info("RMSE: {}", rmse);
        assertTrue(rmse < 0.2d);
    }

    /**
     * Renders the matrix as newline-separated {@code row,column,value} lines, one per cell
     * returned by {@code nonZeroes()} whose value is not NaN. The text is also printed to
     * standard output.
     *
     * @param matrix the preference matrix to serialize
     * @return the preferences as text, without a trailing newline
     */
    protected static String preferencesAsText(Matrix matrix) {
        StringBuilder prefsAsText = new StringBuilder();
        String separator = "";
        for (MatrixSlice slice : matrix) {
            for (Vector.Element e : slice.nonZeroes()) {
                if (!Double.isNaN(e.get())) {
                    prefsAsText.append(separator)
                        .append(slice.index()).append(',')
                        .append(e.index()).append(',')
                        .append(e.get());
                    separator = "\n";
                }
            }
        }
        String text = prefsAsText.toString();
        System.out.println(text);
        return text;
    }

    /**
     * Runs the factorization with long IDs into the intermediate directory, then runs
     * {@link RecommenderJob} on that output, asserting that both tools exit with status 0.
     *
     * @throws Exception if writing the input or running either job fails
     */
    @Test
    public void recommenderJobWithIDMapping() throws Exception {
        writeLines(inputFile, LONG_ID_PREFERENCES);

        // Keep a reference to the configured job so that it is the instance actually executed.
        ParallelALSFactorizationJob alsFactorization = new ParallelALSFactorizationJob();
        alsFactorization.setConf(conf);

        Configuration configuration = getConfiguration();
        String hadoopTmpDirArg = "-Dhadoop.tmp.dir=" + configuration.get("hadoop.tmp.dir");
        String intermediatePath = intermediateDir.getAbsolutePath();

        assertEquals(0, ToolRunner.run(alsFactorization, new String[] {
            hadoopTmpDirArg,
            "--input", inputFile.getAbsolutePath(),
            "--output", intermediatePath,
            "--tempDir", tmpDir.getAbsolutePath(),
            "--lambda", String.valueOf(LAMBDA),
            "--numFeatures", String.valueOf(NUM_FEATURES),
            "--numIterations", String.valueOf(NUM_ITERATIONS),
            "--numThreadsPerSolver", String.valueOf(1),
            "--usesLongIDs", String.valueOf(true)
        }));

        // NOTE(review): reset between the two jobs, presumably to drop state left by the first - confirm.
        SharingMapper.reset();

        assertEquals(0, ToolRunner.run(new RecommenderJob(), new String[] {
            hadoopTmpDirArg,
            "--input", intermediatePath + "/userRatings/",
            "--userFeatures", intermediatePath + "/U/",
            "--itemFeatures", intermediatePath + "/M/",
            "--numRecommendations", String.valueOf(2),
            "--maxRating", String.valueOf(5.0d),
            "--numThreads", String.valueOf(2),
            "--usesLongIDs", String.valueOf(true),
            "--userIDIndex", intermediatePath + "/userIDIndex/",
            "--itemIDIndex", intermediatePath + "/itemIDIndex/",
            "--output", outputDir.getAbsolutePath()
        }));
    }
}
