package org.apache.mahout.classifier.df;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataLoader;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.DescriptorException;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.junit.Test;

@Deprecated
/* loaded from: input_file:org/apache/mahout/classifier/df/DecisionForestTest.class */
public final class DecisionForestTest extends MahoutTestCase {
    private static final String[] TRAIN_DATA = {"sunny,85,85,FALSE,no", "sunny,80,90,TRUE,no", "overcast,83,86,FALSE,yes", "rainy,70,96,FALSE,yes", "rainy,68,80,FALSE,yes", "rainy,65,70,TRUE,no", "overcast,64,65,TRUE,yes", "sunny,72,95,FALSE,no", "sunny,69,70,FALSE,yes", "rainy,75,80,FALSE,yes", "sunny,75,70,TRUE,yes", "overcast,72,90,TRUE,yes", "overcast,81,75,FALSE,yes", "rainy,71,91,TRUE,no"};
    private static final String[] TEST_DATA = {"rainy,70,96,TRUE,-", "overcast,64,65,TRUE,-", "sunny,75,90,TRUE,-"};
    private Random rng;

    @Override // org.apache.mahout.common.MahoutTestCase
    public void setUp() throws Exception {
        super.setUp();
        this.rng = RandomUtils.getRandom();
    }

    private static Data[] generateTrainingDataA() throws DescriptorException {
        Dataset generateDataset = DataLoader.generateDataset("C N N C L", false, TRAIN_DATA);
        Data loadData = DataLoader.loadData(generateDataset, TRAIN_DATA);
        List[] listArr = new List[3];
        for (int i = 0; i < listArr.length; i++) {
            listArr[i] = Lists.newArrayList();
        }
        for (int i2 = 0; i2 < loadData.size(); i2++) {
            if (loadData.get(i2).get(0) == 0.0d) {
                listArr[0].add(loadData.get(i2));
            } else {
                listArr[1].add(loadData.get(i2));
            }
        }
        Data[] dataArr = new Data[listArr.length];
        for (int i3 = 0; i3 < dataArr.length; i3++) {
            dataArr[i3] = new Data(generateDataset, listArr[i3]);
        }
        return dataArr;
    }

    private static Data[] generateTrainingDataB() throws DescriptorException {
        String[] strArr = new String[20];
        for (int i = 0; i < strArr.length; i++) {
            if (i % 3 == 0) {
                strArr[i] = "A," + (40 - i) + ',' + (i + 20);
            } else if (i % 3 == 1) {
                strArr[i] = "B," + (i + 20) + ',' + (40 - i);
            } else {
                strArr[i] = "C," + (i + 20) + ',' + (i + 20);
            }
        }
        Dataset generateDataset = DataLoader.generateDataset("C N L", true, strArr);
        Data[] dataArr = new Data[3];
        dataArr[0] = DataLoader.loadData(generateDataset, strArr);
        String[] strArr2 = new String[20];
        for (int i2 = 0; i2 < strArr2.length; i2++) {
            if (i2 % 2 == 0) {
                strArr2[i2] = "A," + (50 - i2) + ',' + (i2 + 10);
            } else {
                strArr2[i2] = "B," + (i2 + 10) + ',' + (50 - i2);
            }
        }
        dataArr[1] = DataLoader.loadData(generateDataset, strArr2);
        String[] strArr3 = new String[10];
        for (int i3 = 0; i3 < strArr3.length; i3++) {
            strArr3[i3] = "A," + (40 - i3) + ',' + (i3 + 20);
        }
        dataArr[2] = DataLoader.loadData(generateDataset, strArr3);
        return dataArr;
    }

    private DecisionForest buildForest(Data[] dataArr) {
        ArrayList newArrayList = Lists.newArrayList();
        for (Data data : dataArr) {
            DecisionTreeBuilder decisionTreeBuilder = new DecisionTreeBuilder();
            decisionTreeBuilder.setM(data.getDataset().nbAttributes() - 1);
            decisionTreeBuilder.setMinSplitNum(0);
            decisionTreeBuilder.setComplemented(false);
            newArrayList.add(decisionTreeBuilder.build(this.rng, data));
        }
        return new DecisionForest(newArrayList);
    }

    @Test
    public void testClassify() throws DescriptorException {
        Data[] generateTrainingDataA = generateTrainingDataA();
        DecisionForest buildForest = buildForest(generateTrainingDataA);
        Dataset dataset = generateTrainingDataA[0].getDataset();
        Data loadData = DataLoader.loadData(dataset, TEST_DATA);
        double valueOf = dataset.valueOf(4, "no");
        dataset.valueOf(4, "yes");
        assertEquals(valueOf, buildForest.classify(loadData.getDataset(), this.rng, loadData.get(0)), 1.0E-6d);
        assertEquals(valueOf, buildForest.classify(loadData.getDataset(), this.rng, loadData.get(2)), 1.0E-6d);
    }

    /* JADX WARN: Type inference failed for: r0v10, types: [double[], double[][], java.lang.Object[]] */
    /* JADX WARN: Type inference failed for: r0v19, types: [double[], java.lang.Object[]] */
    @Test
    public void testClassifyData() throws DescriptorException {
        Data[] generateTrainingDataA = generateTrainingDataA();
        DecisionForest buildForest = buildForest(generateTrainingDataA);
        Dataset dataset = generateTrainingDataA[0].getDataset();
        Data loadData = DataLoader.loadData(dataset, TEST_DATA);
        ?? r0 = new double[loadData.size()];
        buildForest.classify(loadData, (double[][]) r0);
        double valueOf = dataset.valueOf(4, "no");
        assertArrayEquals(new double[]{new double[]{valueOf, Double.NaN, Double.NaN}, new double[]{valueOf, dataset.valueOf(4, "yes"), Double.NaN}, new double[]{valueOf, valueOf, Double.NaN}}, r0);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v23, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v31, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v9, types: [double[], double[][]] */
    @Test
    public void testRegression() throws DescriptorException {
        Data[] generateTrainingDataB = generateTrainingDataB();
        DecisionForest[] decisionForestArr = new DecisionForest[generateTrainingDataB.length];
        for (int i = 0; i < generateTrainingDataB.length; i++) {
            Data[] dataArr = new Data[generateTrainingDataB.length - 1];
            int i2 = 0;
            for (int i3 = 0; i3 < generateTrainingDataB.length; i3++) {
                if (i3 != i) {
                    dataArr[i2] = generateTrainingDataB[i3];
                    i2++;
                }
            }
            decisionForestArr[i] = buildForest(dataArr);
        }
        ?? r0 = new double[generateTrainingDataB[0].size()];
        decisionForestArr[0].classify(generateTrainingDataB[0], (double[][]) r0);
        assertArrayEquals(new double[]{20.0d, 20.0d}, r0[0], 1.0E-6d);
        assertArrayEquals(new double[]{39.0d, 29.0d}, r0[1], 1.0E-6d);
        assertArrayEquals(new double[]{Double.NaN, 29.0d}, r0[2], 1.0E-6d);
        assertArrayEquals(new double[]{Double.NaN, 23.0d}, r0[17], 1.0E-6d);
        ?? r02 = new double[generateTrainingDataB[1].size()];
        decisionForestArr[1].classify(generateTrainingDataB[1], (double[][]) r02);
        assertArrayEquals(new double[]{30.0d, 29.0d}, r02[19], 1.0E-6d);
        ?? r03 = new double[generateTrainingDataB[2].size()];
        decisionForestArr[2].classify(generateTrainingDataB[2], (double[][]) r03);
        assertArrayEquals(new double[]{29.0d, 28.0d}, r03[9], 1.0E-6d);
        assertEquals(20.0d, decisionForestArr[0].classify(generateTrainingDataB[0].getDataset(), this.rng, generateTrainingDataB[0].get(0)), 1.0E-6d);
        assertEquals(34.0d, decisionForestArr[0].classify(generateTrainingDataB[0].getDataset(), this.rng, generateTrainingDataB[0].get(1)), 1.0E-6d);
        assertEquals(29.0d, decisionForestArr[0].classify(generateTrainingDataB[0].getDataset(), this.rng, generateTrainingDataB[0].get(2)), 1.0E-6d);
    }
}
