package org.deeplearning4j.spark.impl.paramavg;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Random;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaRDDLike;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.storage.StorageLevel;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.StatsStorageRouterProvider;
import org.deeplearning4j.api.storage.StorageMetaData;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.api.TrainingHook;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerMultiDataSetFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPDSFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPDSMDSFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPathFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPathMDSFlatMap;
import org.deeplearning4j.spark.api.worker.NetBroadcastTuple;
import org.deeplearning4j.spark.data.BatchAndExportDataSetsFunction;
import org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouterProvider;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementCombineFunction;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.deeplearning4j.spark.impl.paramavg.util.ExportSupport;
import org.deeplearning4j.spark.util.SparkUtils;
import org.deeplearning4j.spark.util.serde.StorageLevelDeserializer;
import org.deeplearning4j.spark.util.serde.StorageLevelSerializer;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@JsonIgnoreProperties({"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath", "trainingMasterUID"})
/* loaded from: input_file:org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.class */
public class ParameterAveragingTrainingMaster implements TrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> {
    // Class-wide SLF4J logger.
    private static final Logger log = LoggerFactory.getLogger(ParameterAveragingTrainingMaster.class);
    // The PortableDataStream-based executeTraining methods coalesce their input down to numWorkers
    // partitions when it has at least COALESCE_THRESHOLD * numWorkers partitions.
    private static final int COALESCE_THRESHOLD = 3;
    // Shared mappers, created lazily by getJsonMapper() / getYamlMapper().
    private static ObjectMapper jsonMapper;
    private static ObjectMapper yamlMapper;
    // Passed through to each ParameterAveragingTrainingWorker.
    private boolean saveUpdater;
    // May be null until training starts; the execute* methods then fall back to the Spark context's default parallelism.
    private Integer numWorkers;
    // Used for split sizing and passed to WorkerConfiguration.
    // NOTE(review): presumably the number of examples in each DataSet object of the RDD - confirm.
    private int rddDataSetNumExamples;
    private int batchSizePerWorker;
    private int averagingFrequency;
    private int prefetchNumBatches;
    // When true, timing events are recorded on 'stats' around count/split/broadcast/fit operations.
    private boolean collectTrainingStats;
    // Only created when stats collection is enabled (see the public constructor and setCollectTrainingStats).
    private ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper stats;
    private int iterationCount;
    private Repartition repartition;
    private RepartitionStrategy repartitionStrategy;

    // Storage level applied to DataSet/MultiDataSet RDDs in the direct training path (null: do not persist).
    @JsonDeserialize(using = StorageLevelDeserializer.class)
    @JsonSerialize(using = StorageLevelSerializer.class)
    private StorageLevel storageLevel;

    // Storage level applied to path and PortableDataStream RDDs (null: do not persist).
    @JsonDeserialize(using = StorageLevelDeserializer.class)
    @JsonSerialize(using = StorageLevelSerializer.class)
    private StorageLevel storageLevelStreams;
    // Direct: train on the RDD as given; any other value: the RDD goes through exportIfRequired and training runs on the resulting paths.
    private RDDTrainingApproach rddTrainingApproach;
    // NOTE(review): presumably the target directory for exported data sets; its use is outside this view - confirm.
    private String exportDirectory;
    // Source of the seeds handed to SparkUtils.balancedRandomSplit.
    private Random rng;
    private Collection<TrainingHook> trainingHookList;
    // Initialised to Integer.MIN_VALUE by every constructor.
    private int lastExportedRDDId;
    // When non-null, deleteTempFiles passes this path to deleteTempDir.
    private String lastRDDExportPath;
    // "<creation time in ms>_<first 8 characters of the JVM UID>".
    private final String trainingMasterUID;
    // Both set by setListeners; the listeners are handed to each worker instance.
    private Collection<IterationListener> listeners;
    private StatsStorageRouter statsStorage;

    /* loaded from: input_file:org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster$Builder.class */
    /**
     * Fluent builder for {@link ParameterAveragingTrainingMaster}.
     *
     * <p>Defaults set by the constructor: batch size per worker 16, averaging frequency 5,
     * no prefetching, {@code Repartition.Always}, {@code RepartitionStrategy.Balanced},
     * storage level {@code MEMORY_ONLY_SER} (streams: {@code MEMORY_ONLY}),
     * {@code RDDTrainingApproach.Export} and no export directory. {@code saveUpdater}
     * defaults to {@code false} and no RNG seed is set unless {@link #rngSeed(long)} is called.
     */
    public static class Builder {
        private boolean saveUpdater;
        private Integer numWorkers;
        private int rddDataSetNumExamples;
        private int batchSizePerWorker;
        private int averagingFrequency;
        private int prefetchNumBatches;
        private Repartition repartition;
        private RepartitionStrategy repartitionStrategy;
        private StorageLevel storageLevel;
        private StorageLevel storageLevelStreams;
        private RDDTrainingApproach rddTrainingApproach;
        private String exportDirectory;
        private Long rngSeed;
        private Collection<TrainingHook> trainingHooks;

        /**
         * Sets the training hooks. The given collection is stored as-is (not copied).
         *
         * @param collection hooks to use; may be null
         * @return this builder
         */
        public Builder trainingHooks(Collection<TrainingHook> collection) {
            this.trainingHooks = collection;
            return this;
        }

        /**
         * Sets the training hooks from an array.
         *
         * @param trainingHookArr hooks to use
         * @return this builder
         */
        public Builder trainingHooks(TrainingHook... trainingHookArr) {
            // Arrays.asList returns a fixed-size view; copy it so that addHook/removeHook on the
            // built training master do not fail with UnsupportedOperationException.
            this.trainingHooks = new ArrayList<>(Arrays.asList(trainingHookArr));
            return this;
        }

        /**
         * Creates a builder with an unspecified (null) number of workers.
         *
         * @param i value for {@code rddDataSetNumExamples}; must be &gt;= 1
         */
        public Builder(int i) {
            this(null, i);
        }

        /**
         * Creates a builder and applies the defaults listed on the class.
         *
         * @param num number of workers; null leaves it unspecified, otherwise must be &gt;= 1
         * @param i   value for {@code rddDataSetNumExamples}; must be &gt;= 1
         * @throws IllegalArgumentException if {@code num} is non-null and &lt;= 0, or {@code i} &lt;= 0
         */
        public Builder(Integer num, int i) {
            this.batchSizePerWorker = 16;
            this.averagingFrequency = 5;
            this.prefetchNumBatches = 0;
            this.repartition = Repartition.Always;
            this.repartitionStrategy = RepartitionStrategy.Balanced;
            this.storageLevel = StorageLevel.MEMORY_ONLY_SER();
            this.storageLevelStreams = StorageLevel.MEMORY_ONLY();
            this.rddTrainingApproach = RDDTrainingApproach.Export;
            this.exportDirectory = null;
            if (num != null && num.intValue() <= 0) {
                throw new IllegalArgumentException("Invalid number of workers: " + num + " (must be >= 1)");
            }
            if (i <= 0) {
                throw new IllegalArgumentException("Invalid rdd data set size: " + i + " (must be >= 1)");
            }
            this.numWorkers = num;
            this.rddDataSetNumExamples = i;
        }

        /** Sets the batch size per worker (not validated here). */
        public Builder batchSizePerWorker(int i) {
            this.batchSizePerWorker = i;
            return this;
        }

        /**
         * Sets the averaging frequency.
         *
         * @throws IllegalArgumentException if {@code i} &lt;= 0
         */
        public Builder averagingFrequency(int i) {
            if (i <= 0) {
                throw new IllegalArgumentException("Invalid input: averaging frequency must be >= 1");
            }
            this.averagingFrequency = i;
            return this;
        }

        /** Sets the number of batches each worker prefetches. */
        public Builder workerPrefetchNumBatches(int i) {
            this.prefetchNumBatches = i;
            return this;
        }

        /** Sets the {@code saveUpdater} flag that is passed on to the workers. */
        public Builder saveUpdater(boolean z) {
            this.saveUpdater = z;
            return this;
        }

        /** Sets the repartition setting. (The method name's spelling is part of the public API.) */
        public Builder repartionData(Repartition repartition) {
            this.repartition = repartition;
            return this;
        }

        /** Sets the repartition strategy. */
        public Builder repartitionStrategy(RepartitionStrategy repartitionStrategy) {
            this.repartitionStrategy = repartitionStrategy;
            return this;
        }

        /** Sets the storage level used to persist data set RDDs; null disables persisting. */
        public Builder storageLevel(StorageLevel storageLevel) {
            this.storageLevel = storageLevel;
            return this;
        }

        /** Sets the storage level used to persist path/stream RDDs; null disables persisting. */
        public Builder storageLevelStreams(StorageLevel storageLevel) {
            this.storageLevelStreams = storageLevel;
            return this;
        }

        /** Sets how RDD training data is handled (direct, or exported first). */
        public Builder rddTrainingApproach(RDDTrainingApproach rDDTrainingApproach) {
            this.rddTrainingApproach = rDDTrainingApproach;
            return this;
        }

        /** Sets the export directory. */
        public Builder exportDirectory(String str) {
            this.exportDirectory = str;
            return this;
        }

        /** Sets the seed for the training master's random number generator. */
        public Builder rngSeed(long j) {
            this.rngSeed = Long.valueOf(j);
            return this;
        }

        /** Builds the training master from the current builder state. */
        public ParameterAveragingTrainingMaster build() {
            return new ParameterAveragingTrainingMaster(this);
        }
    }

    /**
     * No-argument constructor: applies the default field values, creates an unseeded RNG and
     * assigns the training master UID ("&lt;current time in ms&gt;_&lt;up to 8 chars of the JVM UID&gt;").
     * NOTE(review): presumably intended for Jackson deserialization - confirm.
     */
    private ParameterAveragingTrainingMaster() {
        this.lastExportedRDDId = Integer.MIN_VALUE;
        this.exportDirectory = null;
        this.rddTrainingApproach = RDDTrainingApproach.Export;
        this.storageLevelStreams = StorageLevel.MEMORY_ONLY();
        this.iterationCount = 0;

        final String fullJvmUid = UIDProvider.getJVMUID();
        final String uidPrefix = fullJvmUid.length() > 8 ? fullJvmUid.substring(0, 8) : fullJvmUid;
        this.trainingMasterUID = System.currentTimeMillis() + "_" + uidPrefix;
        this.rng = new Random();
    }

    private ParameterAveragingTrainingMaster(Builder builder) {
        this.iterationCount = 0;
        this.storageLevelStreams = StorageLevel.MEMORY_ONLY();
        this.rddTrainingApproach = RDDTrainingApproach.Export;
        this.exportDirectory = null;
        this.lastExportedRDDId = Integer.MIN_VALUE;
        this.saveUpdater = builder.saveUpdater;
        this.numWorkers = builder.numWorkers;
        this.rddDataSetNumExamples = builder.rddDataSetNumExamples;
        this.batchSizePerWorker = builder.batchSizePerWorker;
        this.averagingFrequency = builder.averagingFrequency;
        this.prefetchNumBatches = builder.prefetchNumBatches;
        this.repartition = builder.repartition;
        this.repartitionStrategy = builder.repartitionStrategy;
        this.storageLevel = builder.storageLevel;
        this.storageLevelStreams = builder.storageLevelStreams;
        this.rddTrainingApproach = builder.rddTrainingApproach;
        this.exportDirectory = builder.exportDirectory;
        this.trainingHookList = builder.trainingHooks;
        if (builder.rngSeed == null) {
            this.rng = new Random();
        } else {
            this.rng = new Random(builder.rngSeed.longValue());
        }
        String jvmuid = UIDProvider.getJVMUID();
        this.trainingMasterUID = System.currentTimeMillis() + "_" + (jvmuid.length() <= 8 ? jvmuid : jvmuid.substring(0, 8));
    }

    /**
     * Convenience constructor: delegates with {@code Repartition.Always},
     * {@code RepartitionStrategy.Balanced} and training-stats collection disabled.
     *
     * @param z   saveUpdater flag
     * @param num number of workers
     * @param i   rddDataSetNumExamples
     * @param i2  batch size per worker
     * @param i3  averaging frequency
     * @param i4  number of batches to prefetch
     */
    public ParameterAveragingTrainingMaster(boolean z, Integer num, int i, int i2, int i3, int i4) {
        this(z, num, i, i2, i3, i4, Repartition.Always, RepartitionStrategy.Balanced, false);
    }

    /**
     * Convenience constructor: delegates with storage level {@code StorageLevel.MEMORY_ONLY_SER()}.
     *
     * @param z                   saveUpdater flag
     * @param num                 number of workers
     * @param i                   rddDataSetNumExamples
     * @param i2                  batch size per worker
     * @param i3                  averaging frequency
     * @param i4                  number of batches to prefetch
     * @param repartition         repartition setting
     * @param repartitionStrategy repartition strategy
     * @param z2                  whether to collect training stats
     */
    public ParameterAveragingTrainingMaster(boolean z, Integer num, int i, int i2, int i3, int i4, Repartition repartition, RepartitionStrategy repartitionStrategy, boolean z2) {
        this(z, num, i, i2, i3, i4, repartition, repartitionStrategy, StorageLevel.MEMORY_ONLY_SER(), z2);
    }

    /**
     * Full constructor.
     *
     * @param z                   saveUpdater flag passed to the workers
     * @param num                 number of workers; null leaves it unspecified (the execute* methods
     *                            then use the Spark context's default parallelism), otherwise must be &gt;= 1
     * @param i                   rddDataSetNumExamples; must be &gt;= 1
     * @param i2                  batch size per worker
     * @param i3                  averaging frequency
     * @param i4                  number of batches to prefetch
     * @param repartition         repartition setting
     * @param repartitionStrategy repartition strategy
     * @param storageLevel        storage level used to persist data set RDDs (null: do not persist)
     * @param z2                  whether to collect training stats
     * @throws IllegalArgumentException if {@code num} is non-null and &lt;= 0, or {@code i} &lt;= 0
     */
    public ParameterAveragingTrainingMaster(boolean z, Integer num, int i, int i2, int i3, int i4, Repartition repartition, RepartitionStrategy repartitionStrategy, StorageLevel storageLevel, boolean z2) {
        this.iterationCount = 0;
        this.storageLevelStreams = StorageLevel.MEMORY_ONLY();
        this.rddTrainingApproach = RDDTrainingApproach.Export;
        this.exportDirectory = null;
        this.lastExportedRDDId = Integer.MIN_VALUE;
        // A null worker count is legal everywhere else in this class (Builder, execute* methods),
        // so only validate it when it is actually specified instead of failing with an NPE.
        if (num != null && num.intValue() <= 0) {
            throw new IllegalArgumentException("Invalid number of workers: " + num + " (must be >= 1)");
        }
        if (i <= 0) {
            throw new IllegalArgumentException("Invalid rdd data set size: " + i + " (must be >= 1)");
        }
        this.saveUpdater = z;
        this.numWorkers = num;
        this.rddDataSetNumExamples = i;
        this.batchSizePerWorker = i2;
        this.averagingFrequency = i3;
        this.prefetchNumBatches = i4;
        this.collectTrainingStats = z2;
        this.repartition = repartition;
        this.repartitionStrategy = repartitionStrategy;
        this.storageLevel = storageLevel;
        if (z2) {
            this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
        }
        String jvmuid = UIDProvider.getJVMUID();
        this.trainingMasterUID = System.currentTimeMillis() + "_" + (jvmuid.length() <= 8 ? jvmuid : jvmuid.substring(0, 8));
        this.rng = new Random();
    }

    /**
     * Returns the shared JSON mapper, creating it on first use.
     *
     * @return the JSON {@link ObjectMapper}
     */
    private static synchronized ObjectMapper getJsonMapper() {
        ObjectMapper mapper = jsonMapper;
        if (mapper == null) {
            mapper = getNewMapper(new JsonFactory());
            jsonMapper = mapper;
        }
        return mapper;
    }

    /**
     * Returns the shared YAML mapper, creating it on first use.
     *
     * @return the YAML {@link ObjectMapper}
     */
    private static synchronized ObjectMapper getYamlMapper() {
        ObjectMapper mapper = yamlMapper;
        if (mapper == null) {
            mapper = getNewMapper(new YAMLFactory());
            yamlMapper = mapper;
        }
        return mapper;
    }

    /**
     * Creates a mapper on the given factory that ignores unknown properties, tolerates empty
     * beans, sorts properties alphabetically, indents its output and (de)serializes fields
     * only, regardless of their visibility.
     *
     * @param jsonFactory factory determining the output format (JSON or YAML)
     * @return the configured mapper
     */
    private static ObjectMapper getNewMapper(JsonFactory jsonFactory) {
        final ObjectMapper mapper = new ObjectMapper(jsonFactory);

        // Feature toggles
        mapper.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);
        mapper.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS);
        mapper.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
        mapper.configure(SerializationFeature.INDENT_OUTPUT, true);

        // Turn every accessor off first, then re-enable fields only (order matters)
        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
        mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
        return mapper;
    }

    @Override // org.deeplearning4j.spark.api.TrainingMaster
    /**
     * Removes the given hook; does nothing when no hook collection exists.
     *
     * @param trainingHook hook to remove
     */
    public void removeHook(TrainingHook trainingHook) {
        if (this.trainingHookList != null) {
            this.trainingHookList.remove(trainingHook);
        }
    }

    @Override // org.deeplearning4j.spark.api.TrainingMaster
    /**
     * Adds the given hook, creating the hook collection first if there is none yet.
     *
     * @param trainingHook hook to add
     */
    public void addHook(TrainingHook trainingHook) {
        Collection<TrainingHook> hooks = this.trainingHookList;
        if (hooks == null) {
            hooks = new ArrayList<>();
            this.trainingHookList = hooks;
        }
        hooks.add(trainingHook);
    }

    @Override // org.deeplearning4j.spark.api.TrainingMaster
    /**
     * Serializes this training master to JSON.
     *
     * @return JSON representation
     * @throws RuntimeException if serialization fails (the cause is preserved)
     */
    public String toJson() {
        final ObjectMapper mapper = getJsonMapper();
        try {
            final String json = mapper.writeValueAsString(this);
            return json;
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Error producing JSON representation for ParameterAveragingTrainingMaster", e);
        }
    }

    @Override // org.deeplearning4j.spark.api.TrainingMaster
    /**
     * Serializes this training master to YAML.
     *
     * @return YAML representation
     * @throws RuntimeException if serialization fails (the cause is preserved)
     */
    public String toYaml() {
        final ObjectMapper mapper = getYamlMapper();
        try {
            final String yaml = mapper.writeValueAsString(this);
            return yaml;
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Error producing YAML representation for ParameterAveragingTrainingMaster", e);
        }
    }

    /**
     * Restores a training master from its JSON representation.
     *
     * <p>The {@code stats} helper is excluded from serialization while the
     * {@code collectTrainingStats} flag is not, so a restored instance with the flag set would
     * otherwise have no helper and fail with a NullPointerException on its first stats call.
     * The helper is therefore recreated here when the flag is set.
     *
     * @param str JSON produced by {@link #toJson()}
     * @return the restored training master
     * @throws RuntimeException if the input cannot be parsed (the cause is preserved)
     */
    public static ParameterAveragingTrainingMaster fromJson(String str) {
        try {
            ParameterAveragingTrainingMaster master = getJsonMapper().readValue(str, ParameterAveragingTrainingMaster.class);
            if (master != null && master.collectTrainingStats && master.stats == null) {
                master.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
            }
            return master;
        } catch (IOException e) {
            throw new RuntimeException("Could not parse JSON", e);
        }
    }

    /**
     * Restores a training master from its YAML representation.
     *
     * <p>The {@code stats} helper is excluded from serialization while the
     * {@code collectTrainingStats} flag is not, so a restored instance with the flag set would
     * otherwise have no helper and fail with a NullPointerException on its first stats call.
     * The helper is therefore recreated here when the flag is set.
     *
     * @param str YAML produced by {@link #toYaml()}
     * @return the restored training master
     * @throws RuntimeException if the input cannot be parsed (the cause is preserved)
     */
    public static ParameterAveragingTrainingMaster fromYaml(String str) {
        try {
            ParameterAveragingTrainingMaster master = getYamlMapper().readValue(str, ParameterAveragingTrainingMaster.class);
            if (master != null && master.collectTrainingStats && master.stats == null) {
                master.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
            }
            return master;
        } catch (IOException e) {
            throw new RuntimeException("Could not parse YAML", e);
        }
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Broadcasts the multi-layer network's configuration, parameters and updater state, and
     * returns a worker bound to that broadcast.
     *
     * @param sparkDl4jMultiLayer network wrapper supplying the network and the Spark context
     * @return a new training worker
     */
    @Override
    public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer sparkDl4jMultiLayer) {
        final MultiLayerNetwork net = sparkDl4jMultiLayer.getNetwork();
        final NetBroadcastTuple payload = new NetBroadcastTuple(
                net.getLayerWiseConfigurations(),
                net.params(),
                net.getUpdater().getStateViewArray());

        if (this.collectTrainingStats) {
            this.stats.logBroadcastStart();
        }
        final Broadcast handle = sparkDl4jMultiLayer.getSparkContext().broadcast(payload);
        if (this.collectTrainingStats) {
            this.stats.logBroadcastEnd();
        }

        // First argument 'false': this worker configuration is for a MultiLayerNetwork, not a graph
        final WorkerConfiguration workerConfig = new WorkerConfiguration(
                false,
                this.rddDataSetNumExamples,
                this.batchSizePerWorker,
                this.averagingFrequency,
                this.prefetchNumBatches,
                this.collectTrainingStats);
        return new ParameterAveragingTrainingWorker(handle, this.saveUpdater, workerConfig, this.trainingHookList, this.listeners, getRouterProvider());
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Broadcasts the computation graph's configuration, parameters and updater state, and
     * returns a worker bound to that broadcast.
     *
     * @param sparkComputationGraph graph wrapper supplying the network and the Spark context
     * @return a new training worker
     */
    @Override
    public ParameterAveragingTrainingWorker getWorkerInstance(SparkComputationGraph sparkComputationGraph) {
        final ComputationGraph graph = sparkComputationGraph.getNetwork();
        final NetBroadcastTuple payload = new NetBroadcastTuple(
                graph.getConfiguration(),
                graph.params(),
                graph.getUpdater().getStateViewArray());

        if (this.collectTrainingStats) {
            this.stats.logBroadcastStart();
        }
        final Broadcast handle = sparkComputationGraph.getSparkContext().broadcast(payload);
        if (this.collectTrainingStats) {
            this.stats.logBroadcastEnd();
        }

        // First argument 'true': this worker configuration is for a ComputationGraph
        final WorkerConfiguration workerConfig = new WorkerConfiguration(
                true,
                this.rddDataSetNumExamples,
                this.batchSizePerWorker,
                this.averagingFrequency,
                this.prefetchNumBatches,
                this.collectTrainingStats);
        return new ParameterAveragingTrainingWorker(handle, this.saveUpdater, workerConfig, this.trainingHookList, this.listeners, getRouterProvider());
    }

    /**
     * Computes {@code batchSizePerWorker * averagingFrequency / i} using integer arithmetic.
     *
     * @param i divisor (callers pass the number of examples per data set object)
     * @return the integer quotient
     */
    private int numObjectsEachWorker(int i) {
        final int examplesPerAveragingPeriod = this.batchSizePerWorker * this.averagingFrequency;
        return examplesPerAveragingPeriod / i;
    }

    /**
     * Computes how many data set objects go into one split.
     *
     * @param i number of examples per data set object
     * @return {@code numWorkers * batchSizePerWorker * averagingFrequency} when {@code i == 1};
     *         otherwise {@code numWorkers} times the per-worker object count, which is floored at 1
     */
    private int getNumDataSetObjectsPerSplit(int i) {
        final int workers = this.numWorkers.intValue();
        if (i == 1) {
            return workers * this.batchSizePerWorker * this.averagingFrequency;
        }
        final int perWorker = Math.max(1, numObjectsEachWorker(i));
        return perWorker * workers;
    }

    @Override // org.deeplearning4j.spark.api.TrainingMaster
    /**
     * Trains the multi-layer network on an RDD of data sets, either directly or, for any
     * other training approach, on the paths returned by {@code exportIfRequired}.
     *
     * @param sparkDl4jMultiLayer network to train
     * @param javaRDD             training data
     */
    public void executeTraining(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD) {
        if (this.numWorkers == null) {
            // No explicit worker count: fall back to Spark's default parallelism
            this.numWorkers = sparkDl4jMultiLayer.getSparkContext().defaultParallelism();
        }
        if (this.rddTrainingApproach == RDDTrainingApproach.Direct) {
            executeTrainingDirect(sparkDl4jMultiLayer, javaRDD);
            return;
        }
        final JavaRDD<String> exportedPaths = exportIfRequired(sparkDl4jMultiLayer.getSparkContext(), javaRDD);
        executeTrainingPathsHelper(sparkDl4jMultiLayer, exportedPaths, this.batchSizePerWorker);
    }

    /**
     * Counts the elements of the given RDD, recording the start and end of the count when
     * training stats are being collected.
     *
     * @param javaRDDLike RDD to count
     * @return number of elements
     */
    private <T, Repr extends JavaRDDLike<T, Repr>> long getTotalDataSetObjectCount(JavaRDDLike<T, Repr> javaRDDLike) {
        if (this.collectTrainingStats) {
            this.stats.logCountStart();
        }
        final long numObjects = javaRDDLike.count();
        if (this.collectTrainingStats) {
            this.stats.logCountEnd();
        }
        return numObjects;
    }

    /**
     * Randomly splits a pair RDD via {@code SparkUtils.balancedRandomSplit}, sizing each split
     * from {@code rddDataSetNumExamples}; the split is timed when training stats are collected.
     *
     * @param javaPairRDD data to split
     * @param i           total number of objects in the RDD
     * @return the splits
     */
    private <T, Repr> JavaPairRDD<T, Repr>[] getSplitRDDs(JavaPairRDD<T, Repr> javaPairRDD, int i) {
        final int objectsPerSplit = getNumDataSetObjectsPerSplit(this.rddDataSetNumExamples);
        if (this.collectTrainingStats) {
            this.stats.logSplitStart();
        }
        final long seed = this.rng.nextLong();
        final JavaPairRDD<T, Repr>[] splits = SparkUtils.balancedRandomSplit(i, objectsPerSplit, javaPairRDD, seed);
        if (this.collectTrainingStats) {
            this.stats.logSplitEnd();
        }
        return splits;
    }

    /**
     * Randomly splits an RDD via {@code SparkUtils.balancedRandomSplit}; the split is timed
     * when training stats are collected.
     *
     * @param javaRDD data to split
     * @param i       total number of objects in the RDD
     * @param i2      number of examples per object, used to size each split
     * @return the splits
     */
    private <T> JavaRDD<T>[] getSplitRDDs(JavaRDD<T> javaRDD, int i, int i2) {
        final int objectsPerSplit = getNumDataSetObjectsPerSplit(i2);
        if (this.collectTrainingStats) {
            this.stats.logSplitStart();
        }
        final long seed = this.rng.nextLong();
        final JavaRDD<T>[] splits = SparkUtils.balancedRandomSplit(i, objectsPerSplit, javaRDD, seed);
        if (this.collectTrainingStats) {
            this.stats.logSplitEnd();
        }
        return splits;
    }

    /**
     * Trains the multi-layer network directly on the data set RDD: optionally persists it,
     * counts it, splits it and runs one iteration per split (split numbers start at 1).
     *
     * @param sparkDl4jMultiLayer network to train
     * @param javaRDD             training data
     */
    private void executeTrainingDirect(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD) {
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        if (this.storageLevel != null) {
            javaRDD.persist(this.storageLevel);
        }

        final long objectCount = getTotalDataSetObjectCount(javaRDD);
        final JavaRDD<DataSet>[] splits = getSplitRDDs(javaRDD, (int) objectCount, this.rddDataSetNumExamples);
        final int numSplits = splits.length;
        for (int idx = 0; idx < numSplits; idx++) {
            doIteration(sparkDl4jMultiLayer, splits[idx], idx + 1, numSplits);
        }

        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) objectCount);
        }
    }

    /**
     * Trains the multi-layer network on serialized data sets supplied as portable data streams.
     * The input is coalesced to {@code numWorkers} partitions when it has at least
     * {@code COALESCE_THRESHOLD * numWorkers} partitions.
     *
     * @param sparkDl4jMultiLayer network to train
     * @param javaPairRDD         (path, stream) pairs
     */
    @Override
    @Deprecated
    public void executeTraining(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaPairRDD<String, PortableDataStream> javaPairRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkDl4jMultiLayer.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }

        final int partitionCount = javaPairRDD.partitions().size();
        if (partitionCount >= COALESCE_THRESHOLD * this.numWorkers.intValue()) {
            log.info("Coalescing PortableDataStreams from {} to {} partitions", Integer.valueOf(partitionCount), this.numWorkers);
            javaPairRDD = javaPairRDD.coalesce(this.numWorkers.intValue());
        }
        if (this.storageLevelStreams != null) {
            javaPairRDD.persist(this.storageLevelStreams);
        }

        final long objectCount = getTotalDataSetObjectCount(javaPairRDD);
        final JavaPairRDD<String, PortableDataStream>[] splits = getSplitRDDs(javaPairRDD, (int) objectCount);
        final int numSplits = splits.length;
        for (int idx = 0; idx < numSplits; idx++) {
            final JavaRDD<PortableDataStream> streams = splits[idx].values();
            doIterationPDS(sparkDl4jMultiLayer, null, streams, idx + 1, numSplits);
        }

        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) objectCount);
        }
    }

    @Override // org.deeplearning4j.spark.api.TrainingMaster
    /**
     * Trains the multi-layer network on an RDD of paths, delegating to
     * {@code executeTrainingPathsHelper} with {@code rddDataSetNumExamples} as the
     * examples-per-object value.
     *
     * @param sparkDl4jMultiLayer network to train
     * @param javaRDD             paths to the training data
     */
    public void executeTrainingPaths(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<String> javaRDD) {
        executeTrainingPathsHelper(sparkDl4jMultiLayer, javaRDD, this.rddDataSetNumExamples);
    }

    /**
     * Trains the multi-layer network on an RDD of paths: optionally persists it, counts it,
     * splits it and runs one path-based iteration per split (split numbers start at 1).
     *
     * @param sparkDl4jMultiLayer network to train
     * @param javaRDD             paths to the training data
     * @param i                   number of examples per object, used for split sizing
     */
    private void executeTrainingPathsHelper(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<String> javaRDD, int i) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkDl4jMultiLayer.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        if (this.storageLevelStreams != null) {
            javaRDD.persist(this.storageLevelStreams);
        }

        final long objectCount = getTotalDataSetObjectCount(javaRDD);
        final JavaRDD<String>[] splits = getSplitRDDs(javaRDD, (int) objectCount, i);
        final int numSplits = splits.length;
        for (int idx = 0; idx < numSplits; idx++) {
            doIterationPaths(sparkDl4jMultiLayer, null, splits[idx], idx + 1, numSplits, i);
        }

        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) objectCount);
        }
    }

    @Override // org.deeplearning4j.spark.api.TrainingMaster
    /**
     * Trains the computation graph on an RDD of data sets by mapping each element through
     * {@code DataSetToMultiDataSetFn} and delegating to the MultiDataSet overload.
     *
     * @param sparkComputationGraph graph to train
     * @param javaRDD               training data
     */
    public void executeTraining(SparkComputationGraph sparkComputationGraph, JavaRDD<DataSet> javaRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        final JavaRDD<MultiDataSet> asMultiDataSets = javaRDD.map(new DataSetToMultiDataSetFn());
        executeTrainingMDS(sparkComputationGraph, asMultiDataSets);
    }

    @Override // org.deeplearning4j.spark.api.TrainingMaster
    /**
     * Trains the computation graph on an RDD of multi data sets, either directly or, for any
     * other training approach, on the paths returned by {@code exportIfRequiredMDS}.
     *
     * @param sparkComputationGraph graph to train
     * @param javaRDD               training data
     */
    public void executeTrainingMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        if (this.rddTrainingApproach == RDDTrainingApproach.Direct) {
            executeTrainingDirect(sparkComputationGraph, javaRDD);
            return;
        }
        final JavaRDD<String> exportedPaths = exportIfRequiredMDS(sparkComputationGraph.getSparkContext(), javaRDD);
        executeTrainingPathsMDSHelper(sparkComputationGraph, exportedPaths, this.batchSizePerWorker);
    }

    /**
     * Trains the computation graph directly on the multi data set RDD: optionally persists it,
     * counts it, splits it and runs one iteration per split (split numbers start at 1).
     *
     * @param sparkComputationGraph graph to train
     * @param javaRDD               training data
     */
    private void executeTrainingDirect(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD) {
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        if (this.storageLevel != null) {
            javaRDD.persist(this.storageLevel);
        }

        final long objectCount = getTotalDataSetObjectCount(javaRDD);
        final JavaRDD<MultiDataSet>[] splits = getSplitRDDs(javaRDD, (int) objectCount, this.rddDataSetNumExamples);
        final int numSplits = splits.length;
        for (int idx = 0; idx < numSplits; idx++) {
            doIteration(sparkComputationGraph, splits[idx], idx + 1, numSplits);
        }

        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) objectCount);
        }
    }

    @Override // org.deeplearning4j.spark.api.TrainingMaster
    /**
     * Trains the computation graph on serialized data sets supplied as portable data streams.
     * The input is coalesced to {@code numWorkers} partitions when it has at least
     * {@code COALESCE_THRESHOLD * numWorkers} partitions.
     *
     * @param sparkComputationGraph graph to train
     * @param javaPairRDD           (path, stream) pairs
     */
    public void executeTraining(SparkComputationGraph sparkComputationGraph, JavaPairRDD<String, PortableDataStream> javaPairRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }

        final int partitionCount = javaPairRDD.partitions().size();
        if (partitionCount >= COALESCE_THRESHOLD * this.numWorkers.intValue()) {
            log.info("Coalescing streams from {} to {} partitions", Integer.valueOf(partitionCount), this.numWorkers);
            javaPairRDD = javaPairRDD.coalesce(this.numWorkers.intValue());
        }
        if (this.storageLevelStreams != null) {
            javaPairRDD.persist(this.storageLevelStreams);
        }

        final long objectCount = getTotalDataSetObjectCount(javaPairRDD);
        final JavaPairRDD<String, PortableDataStream>[] splits = getSplitRDDs(javaPairRDD, (int) objectCount);
        final int numSplits = splits.length;
        for (int idx = 0; idx < numSplits; idx++) {
            final JavaRDD<PortableDataStream> streams = splits[idx].values();
            doIterationPDS(null, sparkComputationGraph, streams, idx + 1, numSplits);
        }

        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) objectCount);
        }
    }

    @Override // org.deeplearning4j.spark.api.TrainingMaster
    /**
     * Trains the computation graph on serialized multi data sets supplied as portable data
     * streams. Each split's streams are passed through {@code SparkUtils.repartition} before
     * the iteration runs.
     *
     * @param sparkComputationGraph graph to train
     * @param javaPairRDD           (path, stream) pairs
     */
    public void executeTrainingMDS(SparkComputationGraph sparkComputationGraph, JavaPairRDD<String, PortableDataStream> javaPairRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        if (this.storageLevelStreams != null) {
            javaPairRDD.persist(this.storageLevelStreams);
        }

        final long objectCount = getTotalDataSetObjectCount(javaPairRDD);
        final JavaPairRDD<String, PortableDataStream>[] splits = getSplitRDDs(javaPairRDD, (int) objectCount);
        final int numSplits = splits.length;
        for (int idx = 0; idx < numSplits; idx++) {
            final JavaRDD<PortableDataStream> streams = splits[idx].values();
            if (this.collectTrainingStats) {
                this.stats.logRepartitionStart();
            }
            final JavaRDD<PortableDataStream> repartitioned = SparkUtils.repartition(
                    streams,
                    this.repartition,
                    this.repartitionStrategy,
                    numObjectsEachWorker(this.rddDataSetNumExamples),
                    this.numWorkers.intValue());
            // The end of the repartition is only logged when repartitioning is not disabled
            if (this.collectTrainingStats && this.repartition != Repartition.Never) {
                this.stats.logRepartitionEnd();
            }
            doIterationPDS_MDS(sparkComputationGraph, repartitioned, idx + 1, numSplits);
        }

        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) objectCount);
        }
    }

    @Override // org.deeplearning4j.spark.api.TrainingMaster
    /**
     * Trains the computation graph on an RDD of paths: optionally persists it, counts it,
     * splits it and runs one path-based iteration per split (split numbers start at 1).
     *
     * @param sparkComputationGraph graph to train
     * @param javaRDD               paths to the training data
     */
    public void executeTrainingPaths(SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        if (this.storageLevelStreams != null) {
            javaRDD.persist(this.storageLevelStreams);
        }

        final long objectCount = getTotalDataSetObjectCount(javaRDD);
        final JavaRDD<String>[] splits = getSplitRDDs(javaRDD, (int) objectCount, this.rddDataSetNumExamples);
        final int numSplits = splits.length;
        for (int idx = 0; idx < numSplits; idx++) {
            doIterationPaths(null, sparkComputationGraph, splits[idx], idx + 1, numSplits, this.rddDataSetNumExamples);
        }

        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) objectCount);
        }
    }

    @Override // org.deeplearning4j.spark.api.TrainingMaster
    /**
     * Trains the computation graph on an RDD of paths, delegating to
     * {@code executeTrainingPathsMDSHelper} with {@code rddDataSetNumExamples} as the
     * examples-per-object value.
     *
     * @param sparkComputationGraph graph to train
     * @param javaRDD               paths to the training data
     */
    public void executeTrainingPathsMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD) {
        executeTrainingPathsMDSHelper(sparkComputationGraph, javaRDD, this.rddDataSetNumExamples);
    }

    /**
     * Trains the computation graph on an RDD of paths: optionally persists it, counts it,
     * splits it and runs one {@code doIterationPathsMDS} call per split (split numbers start at 1).
     *
     * @param sparkComputationGraph graph to train
     * @param javaRDD               paths to the training data
     * @param i                     number of examples per object, used for split sizing
     */
    private void executeTrainingPathsMDSHelper(SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, int i) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        if (this.storageLevelStreams != null) {
            javaRDD.persist(this.storageLevelStreams);
        }

        final long objectCount = getTotalDataSetObjectCount(javaRDD);
        final JavaRDD<String>[] splits = getSplitRDDs(javaRDD, (int) objectCount, i);
        final int numSplits = splits.length;
        for (int idx = 0; idx < numSplits; idx++) {
            doIterationPathsMDS(sparkComputationGraph, splits[idx], idx + 1, numSplits, i);
        }

        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) objectCount);
        }
    }

    /**
     * Enables or disables collection of training statistics.
     *
     * <p>Disabling discards any existing stats helper; enabling creates a helper only if
     * none exists yet, so previously collected stats are kept.
     *
     * @param z {@code true} to collect training stats, {@code false} to stop and drop them
     */
    @Override
    public void setCollectTrainingStats(boolean z) {
        this.collectTrainingStats = z;
        if (z) {
            if (this.stats == null) {
                this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
            }
        } else {
            this.stats = null;
        }
    }

    /**
     * @return whether training statistics collection is currently enabled
     */
    @Override
    public boolean getIsCollectTrainingStats() {
        return collectTrainingStats;
    }

    /**
     * Builds the training statistics collected so far.
     *
     * @return the result of building the stats helper, or {@code null} when no stats helper
     *         exists (stats collection disabled)
     */
    @Override
    public SparkTrainingStats getTrainingStats() {
        return this.stats == null ? null : this.stats.build();
    }

    /**
     * Sets the iteration listeners without a stats storage router. Note that this resets
     * the router to {@code null}, since it delegates to the two-argument overload.
     *
     * @param collection listeners to store
     */
    @Override
    public void setListeners(Collection<IterationListener> collection) {
        final StatsStorageRouter noRouter = null;
        setListeners(noRouter, collection);
    }

    /**
     * Stores the iteration listeners together with the stats storage router.
     *
     * @param statsStorageRouter router stored as {@code statsStorage}; may be {@code null}
     * @param collection         listeners to store
     */
    @Override
    public void setListeners(StatsStorageRouter statsStorageRouter, Collection<IterationListener> collection) {
        statsStorage = statsStorageRouter;
        listeners = collection;
    }

    /**
     * Deletes the directory of the most recent RDD export, if there is one.
     *
     * @param javaSparkContext context whose Hadoop configuration is used for the deletion
     * @return {@code true} when nothing has been exported or the deletion succeeded;
     *         otherwise whatever {@code deleteTempDir} reports
     */
    @Override
    public boolean deleteTempFiles(JavaSparkContext javaSparkContext) {
        if (this.lastRDDExportPath == null) {
            // Nothing was exported, so there is nothing to clean up.
            return true;
        }
        return deleteTempDir(javaSparkContext, this.lastRDDExportPath);
    }

    /**
     * Convenience overload that wraps the Scala-side context in a {@link JavaSparkContext}
     * and delegates to {@link #deleteTempFiles(JavaSparkContext)}.
     *
     * @param sparkContext Spark context to wrap
     * @return result of the delegated call
     */
    @Override
    public boolean deleteTempFiles(SparkContext sparkContext) {
        JavaSparkContext wrapped = new JavaSparkContext(sparkContext);
        return deleteTempFiles(wrapped);
    }

    /**
     * Trains one split of {@code DataSet} objects on a multi-layer network: repartitions the
     * split, runs the workers over each partition and folds their results into the network
     * through {@code processResults}.
     *
     * @param sparkDl4jMultiLayer network wrapper to train
     * @param javaRDD             the split to train on
     * @param i                   split number (logged as "split {} of {}")
     * @param i2                  total number of splits
     */
    private void doIteration(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }

        JavaRDD repartitioned = SparkUtils.repartition(
                javaRDD,
                this.repartition,
                this.repartitionStrategy,
                numObjectsEachWorker(this.rddDataSetNumExamples),
                this.numWorkers.intValue());
        int numPartitions = repartitioned.partitions().size();

        // The repartition end time is only recorded when repartitioning is enabled.
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }

        ExecuteWorkerFlatMap workerFunction = new ExecuteWorkerFlatMap(getWorkerInstance(sparkDl4jMultiLayer));
        JavaRDD workerResults = repartitioned.mapPartitions(workerFunction);
        processResults(sparkDl4jMultiLayer, null, workerResults, i, i2);

        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains one split of {@code PortableDataStream} objects. Exactly one of the two network
     * arguments is expected to be non-null; the multi-layer network takes precedence when it
     * is present.
     *
     * @param sparkDl4jMultiLayer   multi-layer network wrapper, or {@code null} to train the graph
     * @param sparkComputationGraph computation graph wrapper, used when the multi-layer one is {@code null}
     * @param javaRDD               the split to train on
     * @param i                     split number (logged as "split {} of {}")
     * @param i2                    total number of splits
     */
    private void doIterationPDS(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<PortableDataStream> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }

        JavaRDD repartitioned = SparkUtils.repartition(
                javaRDD,
                this.repartition,
                this.repartitionStrategy,
                numObjectsEachWorker(this.rddDataSetNumExamples),
                this.numWorkers.intValue());
        int numPartitions = repartitioned.partitions().size();

        // The repartition end time is only recorded when repartitioning is enabled.
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }

        ExecuteWorkerPDSFlatMap workerFunction;
        if (sparkDl4jMultiLayer != null) {
            workerFunction = new ExecuteWorkerPDSFlatMap(getWorkerInstance(sparkDl4jMultiLayer));
        } else {
            workerFunction = new ExecuteWorkerPDSFlatMap(getWorkerInstance(sparkComputationGraph));
        }
        JavaRDD workerResults = repartitioned.mapPartitions(workerFunction);
        processResults(sparkDl4jMultiLayer, sparkComputationGraph, workerResults, i, i2);

        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains one split of string entries (paths, per the method name) using
     * {@code ExecuteWorkerPathFlatMap} workers. Exactly one of the two network arguments is
     * expected to be non-null; the multi-layer network takes precedence when it is present.
     *
     * @param sparkDl4jMultiLayer   multi-layer network wrapper, or {@code null} to train the graph
     * @param sparkComputationGraph computation graph wrapper, used when the multi-layer one is {@code null}
     * @param javaRDD               the split to train on
     * @param i                     split number (logged as "split {} of {}")
     * @param i2                    total number of splits
     * @param i3                    value passed to {@code numObjectsEachWorker} when repartitioning
     */
    private void doIterationPaths(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, int i, int i2, int i3) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }

        JavaRDD repartitioned = SparkUtils.repartition(
                javaRDD,
                this.repartition,
                this.repartitionStrategy,
                numObjectsEachWorker(i3),
                this.numWorkers.intValue());
        int numPartitions = repartitioned.partitions().size();

        // The repartition end time is only recorded when repartitioning is enabled.
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }

        ExecuteWorkerPathFlatMap workerFunction;
        if (sparkDl4jMultiLayer != null) {
            workerFunction = new ExecuteWorkerPathFlatMap(getWorkerInstance(sparkDl4jMultiLayer));
        } else {
            workerFunction = new ExecuteWorkerPathFlatMap(getWorkerInstance(sparkComputationGraph));
        }
        JavaRDD workerResults = repartitioned.mapPartitions(workerFunction);
        processResults(sparkDl4jMultiLayer, sparkComputationGraph, workerResults, i, i2);

        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains one split of string entries (paths, per the method name) on a computation graph
     * using {@code ExecuteWorkerPathMDSFlatMap} workers.
     *
     * @param sparkComputationGraph computation graph wrapper to train
     * @param javaRDD               the split to train on
     * @param i                     split number (logged as "split {} of {}")
     * @param i2                    total number of splits
     * @param i3                    value passed to {@code numObjectsEachWorker} when repartitioning
     */
    private void doIterationPathsMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, int i, int i2, int i3) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }

        JavaRDD repartitioned = SparkUtils.repartition(
                javaRDD,
                this.repartition,
                this.repartitionStrategy,
                numObjectsEachWorker(i3),
                this.numWorkers.intValue());
        int numPartitions = repartitioned.partitions().size();

        // The repartition end time is only recorded when repartitioning is enabled.
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }

        ExecuteWorkerPathMDSFlatMap workerFunction = new ExecuteWorkerPathMDSFlatMap(getWorkerInstance(sparkComputationGraph));
        JavaRDD workerResults = repartitioned.mapPartitions(workerFunction);
        processResults(null, sparkComputationGraph, workerResults, i, i2);

        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains one split of {@code MultiDataSet} objects on a computation graph: repartitions
     * the split, runs the workers over each partition and folds their results into the graph
     * through {@code processResults}.
     *
     * <p>Brought in line with the other {@code doIteration*} methods in this class: the
     * partition count reported to the stats is taken from the repartitioned RDD (the one the
     * workers actually run on) rather than from the input split, and the repartition timing
     * hooks are recorded in the same way as the sibling methods do.
     *
     * @param sparkComputationGraph computation graph wrapper to train
     * @param javaRDD               the split to train on
     * @param i                     split number (logged as "split {} of {}")
     * @param i2                    total number of splits
     */
    private void doIteration(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers", new Object[]{Integer.valueOf(i), Integer.valueOf(i2), Integer.valueOf(this.batchSizePerWorker), Integer.valueOf(this.averagingFrequency), this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
        }
        if (this.collectTrainingStats) {
            this.stats.logRepartitionStart();
        }
        JavaRDD repartition = SparkUtils.repartition(javaRDD, this.repartition, this.repartitionStrategy, numObjectsEachWorker(this.rddDataSetNumExamples), this.numWorkers.intValue());
        // Count the partitions of the RDD that mapPartitions runs on, not of the input split.
        int size = repartition.partitions().size();
        // The repartition end time is only recorded when repartitioning is enabled.
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }
        processResults(null, sparkComputationGraph, repartition.mapPartitions(new ExecuteWorkerMultiDataSetFlatMap(getWorkerInstance(sparkComputationGraph))), i, i2);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(size);
        }
    }

    /**
     * Trains one split of {@code PortableDataStream} objects on a computation graph using
     * {@code ExecuteWorkerPDSMDSFlatMap} workers.
     *
     * @param sparkComputationGraph computation graph wrapper to train
     * @param javaRDD               the split to train on
     * @param i                     split number (logged as "split {} of {}")
     * @param i2                    total number of splits
     */
    private void doIterationPDS_MDS(SparkComputationGraph sparkComputationGraph, JavaRDD<PortableDataStream> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }

        JavaRDD repartitioned = SparkUtils.repartition(
                javaRDD,
                this.repartition,
                this.repartitionStrategy,
                numObjectsEachWorker(this.rddDataSetNumExamples),
                this.numWorkers.intValue());
        int numPartitions = repartitioned.partitions().size();

        // The repartition end time is only recorded when repartitioning is enabled.
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }

        ExecuteWorkerPDSMDSFlatMap workerFunction = new ExecuteWorkerPDSMDSFlatMap(getWorkerInstance(sparkComputationGraph));
        JavaRDD workerResults = repartitioned.mapPartitions(workerFunction);
        processResults(null, sparkComputationGraph, workerResults, i, i2);

        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Aggregates the per-worker training results of one split and applies them to the
     * network being trained.
     *
     * <p>Steps, in order: aggregate the worker results; if a parameter sum is available,
     * divide it (and the updater state sum, when present) by the aggregation count and push
     * the result plus the mean score into the network; forward listener metadata, static
     * info and updates to the stats storage router if one is set; flush the executioner
     * queue when it is a {@code GridExecutioner}; finally advance the configuration's
     * iteration count by {@code numIterations * averagingFrequency}.
     *
     * @param sparkDl4jMultiLayer   multi-layer network wrapper, or {@code null} to update the graph
     * @param sparkComputationGraph computation graph wrapper, used when the multi-layer one is {@code null}
     * @param javaRDD               worker results to aggregate
     * @param i                     split number (logged as "split {} of {}")
     * @param i2                    total number of splits
     */
    private void processResults(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<ParameterAveragingTrainingResult> javaRDD, int i, int i2) {
        if (this.collectTrainingStats) {
            this.stats.logAggregateStartTime();
        }
        ParameterAveragingAggregationTuple aggregated = (ParameterAveragingAggregationTuple) javaRDD.aggregate((Object) null, new ParameterAveragingElementAddFunction(), new ParameterAveragingElementCombineFunction());
        final INDArray paramSum = aggregated.getParametersSum();
        final int workerCount = aggregated.getAggregationsCount();
        final SparkTrainingStats workerStats = aggregated.getSparkTrainingStats();
        final boolean hasParams = paramSum != null;

        if (this.collectTrainingStats) {
            this.stats.logAggregationEndTime();
            this.stats.logProcessParamsUpdaterStart();
        }

        if (!hasParams) {
            log.info("Skipping imbalanced split with no data for all executors");
        } else {
            // In-place division turns the sums into averages.
            paramSum.divi(Integer.valueOf(workerCount));
            INDArray updaterStateSum = aggregated.getUpdaterStateSum();
            if (updaterStateSum != null) {
                updaterStateSum.divi(Integer.valueOf(workerCount));
            }

            if (sparkDl4jMultiLayer == null) {
                ComputationGraph graph = sparkComputationGraph.getNetwork();
                graph.setParams(paramSum);
                if (updaterStateSum != null) {
                    graph.getUpdater().setStateViewArray(updaterStateSum);
                }
                sparkComputationGraph.setScore(aggregated.getScoreSum() / aggregated.getAggregationsCount());
            } else {
                MultiLayerNetwork net = sparkDl4jMultiLayer.getNetwork();
                net.setParameters(paramSum);
                if (updaterStateSum != null) {
                    net.getUpdater().setStateViewArray((Layer) null, updaterStateSum, false);
                }
                sparkDl4jMultiLayer.setScore(aggregated.getScoreSum() / aggregated.getAggregationsCount());
            }
        }

        if (this.collectTrainingStats) {
            this.stats.logProcessParamsUpdaterEnd();
            this.stats.addWorkerStats(workerStats);
        }

        // Forward whatever the workers' listeners produced to the configured router.
        if (this.statsStorage != null) {
            Collection<StorageMetaData> metaData = aggregated.getListenerMetaData();
            if (metaData != null && !metaData.isEmpty()) {
                this.statsStorage.putStorageMetaData(metaData);
            }
            Collection<Persistable> staticInfo = aggregated.getListenerStaticInfo();
            if (staticInfo != null && !staticInfo.isEmpty()) {
                this.statsStorage.putStaticInfo(staticInfo);
            }
            Collection<Persistable> updates = aggregated.getListenerUpdates();
            if (updates != null && !updates.isEmpty()) {
                this.statsStorage.putUpdate(updates);
            }
        }

        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueueBlocking();
        }
        log.info("Completed training of split {} of {}", i, i2);

        // Only splits that actually produced parameters advance the iteration counter.
        if (hasParams) {
            if (sparkDl4jMultiLayer == null) {
                ComputationGraphConfiguration graphConf = sparkComputationGraph.getNetwork().getConfiguration();
                graphConf.setIterationCount(graphConf.getIterationCount() + (sparkComputationGraph.getNetwork().conf().getNumIterations() * this.averagingFrequency));
            } else {
                MultiLayerConfiguration netConf = sparkDl4jMultiLayer.getNetwork().getLayerWiseConfigurations();
                netConf.setIterationCount(netConf.getIterationCount() + (sparkDl4jMultiLayer.getNetwork().conf().getNumIterations() * this.averagingFrequency));
            }
        }
    }

    /**
     * Makes sure the given {@code DataSet} RDD has been exported and returns an RDD over the
     * text files under the export's {@code paths/} directory.
     *
     * <p>An RDD whose id matches the last exported one is not exported again. When a
     * different RDD was exported earlier, its directory is deleted first (the result of that
     * deletion is not checked).
     *
     * @param javaSparkContext context used for the support check, cleanup and reading back
     * @param javaRDD          RDD to export
     * @return RDD read from {@code <baseDir>paths/}
     */
    private JavaRDD<String> exportIfRequired(JavaSparkContext javaSparkContext, JavaRDD<DataSet> javaRDD) {
        ExportSupport.assertExportSupported(javaSparkContext);
        if (this.collectTrainingStats) {
            this.stats.logExportStart();
        }

        final int currentRddId = javaRDD.id();
        // Integer.MIN_VALUE marks "nothing has been exported yet".
        final boolean exportedBefore = this.lastExportedRDDId != Integer.MIN_VALUE;
        final String baseDir;
        if (exportedBefore && this.lastExportedRDDId == currentRddId) {
            baseDir = getBaseDirForRDD(javaRDD);
        } else {
            if (exportedBefore) {
                deleteTempDir(javaSparkContext, this.lastRDDExportPath);
            }
            baseDir = export(javaRDD);
        }

        if (this.collectTrainingStats) {
            this.stats.logExportEnd();
        }
        return javaSparkContext.textFile(baseDir + "paths/");
    }

    /**
     * Makes sure the given {@code MultiDataSet} RDD has been exported and returns an RDD over
     * the text files under the export's {@code paths/} directory.
     *
     * <p>An RDD whose id matches the last exported one is not exported again. When a
     * different RDD was exported earlier, its directory is deleted first (the result of that
     * deletion is not checked).
     *
     * @param javaSparkContext context used for the support check, cleanup and reading back
     * @param javaRDD          RDD to export
     * @return RDD read from {@code <baseDir>paths/}
     */
    private JavaRDD<String> exportIfRequiredMDS(JavaSparkContext javaSparkContext, JavaRDD<MultiDataSet> javaRDD) {
        ExportSupport.assertExportSupported(javaSparkContext);
        if (this.collectTrainingStats) {
            this.stats.logExportStart();
        }

        final int currentRddId = javaRDD.id();
        // Integer.MIN_VALUE marks "nothing has been exported yet".
        final boolean exportedBefore = this.lastExportedRDDId != Integer.MIN_VALUE;
        final String baseDir;
        if (exportedBefore && this.lastExportedRDDId == currentRddId) {
            baseDir = getBaseDirForRDD(javaRDD);
        } else {
            if (exportedBefore) {
                deleteTempDir(javaSparkContext, this.lastRDDExportPath);
            }
            baseDir = exportMDS(javaRDD);
        }

        if (this.collectTrainingStats) {
            this.stats.logExportEnd();
        }
        return javaSparkContext.textFile(baseDir + "paths/");
    }

    /**
     * Exports a {@code DataSet} RDD beneath its base directory: the batching/export function
     * is given {@code <baseDir>data/} as its target, and the RDD it produces is saved as text
     * under {@code <baseDir>paths/}. Afterwards the RDD id and base directory are remembered
     * as the last export.
     *
     * @param javaRDD RDD to export
     * @return the base directory used for the export (ends with {@code /})
     */
    private String export(JavaRDD<DataSet> javaRDD) {
        final String baseDir = getBaseDirForRDD(javaRDD);
        final String dataDir = baseDir + "data/";
        log.info("Initiating RDD<DataSet> export at {}", baseDir);

        BatchAndExportDataSetsFunction exporter = new BatchAndExportDataSetsFunction(this.batchSizePerWorker, dataDir);
        javaRDD.mapPartitionsWithIndex(exporter, true)
                .saveAsTextFile(baseDir + "paths/");

        log.info("RDD<DataSet> export complete at {}", baseDir);
        this.lastExportedRDDId = javaRDD.id();
        this.lastRDDExportPath = baseDir;
        return baseDir;
    }

    /**
     * Exports a {@code MultiDataSet} RDD beneath its base directory: the batching/export
     * function is given {@code <baseDir>data/} as its target, and the RDD it produces is
     * saved as text under {@code <baseDir>paths/}. Afterwards the RDD id and base directory
     * are remembered as the last export.
     *
     * @param javaRDD RDD to export
     * @return the base directory used for the export (ends with {@code /})
     */
    private String exportMDS(JavaRDD<MultiDataSet> javaRDD) {
        final String baseDir = getBaseDirForRDD(javaRDD);
        final String dataDir = baseDir + "data/";
        log.info("Initiating RDD<MultiDataSet> export at {}", baseDir);

        BatchAndExportMultiDataSetsFunction exporter = new BatchAndExportMultiDataSetsFunction(this.batchSizePerWorker, dataDir);
        javaRDD.mapPartitionsWithIndex(exporter, true)
                .saveAsTextFile(baseDir + "paths/");

        log.info("RDD<MultiDataSet> export complete at {}", baseDir);
        this.lastExportedRDDId = javaRDD.id();
        this.lastRDDExportPath = baseDir;
        return baseDir;
    }

    /**
     * Computes the export base directory for an RDD:
     * {@code <exportDirectory>/<trainingMasterUID>/<rddId>/}.
     *
     * <p>If no export directory has been configured yet, it is initialised here from
     * {@code getDefaultExportDirectory} and kept for later calls.
     *
     * @param javaRDD RDD whose id forms the last path segment
     * @return the base directory, always ending with {@code /}
     */
    private String getBaseDirForRDD(JavaRDD<?> javaRDD) {
        if (this.exportDirectory == null) {
            this.exportDirectory = getDefaultExportDirectory(javaRDD.context());
        }
        // Avoid a doubled separator when the configured directory already ends with one.
        String root = this.exportDirectory.endsWith("/") ? this.exportDirectory : this.exportDirectory + "/";
        return root + this.trainingMasterUID + "/" + javaRDD.id() + "/";
    }

    /**
     * Recursively deletes a temporary directory through the Hadoop {@link FileSystem}
     * resolved for its URI.
     *
     * <p>Failure handling is unchanged from the previous nested form, but expressed as a
     * single try with two handlers: the former outer {@code IOException} handler could never
     * run because the inner handler already caught every {@code IOException}.
     *
     * @param javaSparkContext context supplying the Hadoop configuration
     * @param str              directory to delete; must be a syntactically valid URI
     * @return {@code true} if the delete call completed without an I/O error, {@code false}
     *         if an {@link IOException} occurred (logged as a warning)
     * @throws RuntimeException wrapping a {@link URISyntaxException} when {@code str} is not
     *                          a valid URI
     */
    private boolean deleteTempDir(JavaSparkContext javaSparkContext, String str) {
        log.info("Attempting to delete temporary directory: {}", str);
        try {
            FileSystem.get(new URI(str), javaSparkContext.hadoopConfiguration()).delete(new Path(str), true);
            log.info("Deleted temporary directory: {}", str);
            return true;
        } catch (IOException e) {
            // Best-effort cleanup: report the failure to the caller instead of aborting.
            log.warn("Could not delete temporary directory: {}", str, e);
            return false;
        } catch (URISyntaxException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Derives the default export directory from the Hadoop {@code hadoop.tmp.dir} setting:
     * that directory (with a {@code /} appended unless it already ends in {@code /} or
     * {@code \}) followed by {@code dl4j/}.
     *
     * @param sparkContext context whose Hadoop configuration is read
     * @return the default export directory, ending with {@code dl4j/}
     */
    private String getDefaultExportDirectory(SparkContext sparkContext) {
        String tmpDir = sparkContext.hadoopConfiguration().get("hadoop.tmp.dir");
        boolean endsWithSeparator = tmpDir.endsWith("/") || tmpDir.endsWith("\\");
        String normalized = endsWithSeparator ? tmpDir : tmpDir + "/";
        return normalized + "dl4j/";
    }

    /**
     * @return a new {@code VanillaStatsStorageRouterProvider} when a stats storage router is
     *         configured, otherwise {@code null}
     */
    private StatsStorageRouterProvider getRouterProvider() {
        return this.statsStorage != null ? new VanillaStatsStorageRouterProvider() : null;
    }

    /**
     * @return the current value of the {@code saveUpdater} flag
     */
    public boolean isSaveUpdater() {
        return saveUpdater;
    }

    /**
     * @return the configured number of workers; may be {@code null} until a training call
     *         defaults it from the Spark context's default parallelism
     */
    public Integer getNumWorkers() {
        return numWorkers;
    }

    /**
     * @return the {@code rddDataSetNumExamples} setting used when splitting and
     *         repartitioning RDDs
     */
    public int getRddDataSetNumExamples() {
        return rddDataSetNumExamples;
    }

    /**
     * @return the per-worker batch size (logged as {@code workerMiniBatchSize})
     */
    public int getBatchSizePerWorker() {
        return batchSizePerWorker;
    }

    /**
     * @return the averaging frequency (logged as {@code averagingFreq})
     */
    public int getAveragingFrequency() {
        return averagingFrequency;
    }

    /**
     * @return the {@code prefetchNumBatches} setting
     */
    public int getPrefetchNumBatches() {
        return prefetchNumBatches;
    }

    /**
     * @return whether training statistics collection is enabled
     */
    public boolean isCollectTrainingStats() {
        return collectTrainingStats;
    }

    /**
     * @return the stats helper, or {@code null} when stats collection has been disabled
     */
    public ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper getStats() {
        return stats;
    }

    /**
     * @return the value of the {@code iterationCount} field
     */
    public int getIterationCount() {
        return iterationCount;
    }

    /**
     * @return the repartition setting passed to {@code SparkUtils.repartition}
     */
    public Repartition getRepartition() {
        return repartition;
    }

    /**
     * @return the repartition strategy passed to {@code SparkUtils.repartition}
     */
    public RepartitionStrategy getRepartitionStrategy() {
        return repartitionStrategy;
    }

    /**
     * @return the value of the {@code storageLevel} field
     */
    public StorageLevel getStorageLevel() {
        return storageLevel;
    }

    /**
     * @return the storage level used to persist path RDDs before training, or {@code null}
     *         when they are not persisted
     */
    public StorageLevel getStorageLevelStreams() {
        return storageLevelStreams;
    }

    /**
     * @return the value of the {@code rddTrainingApproach} field
     */
    public RDDTrainingApproach getRddTrainingApproach() {
        return rddTrainingApproach;
    }

    /**
     * @return the export root directory; may be {@code null} until it is defaulted on the
     *         first export
     */
    public String getExportDirectory() {
        return exportDirectory;
    }

    /**
     * @return the value of the {@code rng} field
     */
    public Random getRng() {
        return rng;
    }

    /**
     * @return the training hook collection held by this master (the internal reference, not
     *         a copy)
     */
    public Collection<TrainingHook> getTrainingHookList() {
        return trainingHookList;
    }

    /**
     * @return the id of the most recently exported RDD; the export logic treats
     *         {@code Integer.MIN_VALUE} as "nothing exported yet"
     */
    public int getLastExportedRDDId() {
        return lastExportedRDDId;
    }

    /**
     * @return the base directory of the most recent RDD export, or {@code null} if none
     */
    public String getLastRDDExportPath() {
        return lastRDDExportPath;
    }

    /**
     * @return the identifier of this training master, used as a path segment of the export
     *         directory
     */
    public String getTrainingMasterUID() {
        return trainingMasterUID;
    }

    /**
     * @return the iteration listeners last passed to {@code setListeners} (the internal
     *         reference, not a copy)
     */
    public Collection<IterationListener> getListeners() {
        return listeners;
    }

    /**
     * @return the stats storage router that receives listener metadata and updates, or
     *         {@code null} if none is configured
     */
    public StatsStorageRouter getStatsStorage() {
        return statsStorage;
    }

    /**
     * @param z new value of the {@code saveUpdater} flag
     */
    public void setSaveUpdater(boolean z) {
        saveUpdater = z;
    }

    /**
     * @param num number of workers; {@code null} lets the next training call default it from
     *            the Spark context's default parallelism
     */
    public void setNumWorkers(Integer num) {
        numWorkers = num;
    }

    /**
     * @param i new value of the {@code rddDataSetNumExamples} setting
     */
    public void setRddDataSetNumExamples(int i) {
        rddDataSetNumExamples = i;
    }

    /**
     * @param i new per-worker batch size
     */
    public void setBatchSizePerWorker(int i) {
        batchSizePerWorker = i;
    }

    /**
     * @param i new averaging frequency
     */
    public void setAveragingFrequency(int i) {
        averagingFrequency = i;
    }

    /**
     * @param i new value of the {@code prefetchNumBatches} setting
     */
    public void setPrefetchNumBatches(int i) {
        prefetchNumBatches = i;
    }

    /**
     * Replaces the stats helper directly. This does not touch the
     * {@code collectTrainingStats} flag.
     *
     * @param parameterAveragingTrainingMasterStatsHelper helper to store; may be {@code null}
     */
    public void setStats(ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper parameterAveragingTrainingMasterStatsHelper) {
        stats = parameterAveragingTrainingMasterStatsHelper;
    }

    /**
     * @param i new value of the {@code iterationCount} field
     */
    public void setIterationCount(int i) {
        iterationCount = i;
    }

    /**
     * @param repartition repartition setting to pass to {@code SparkUtils.repartition}
     */
    public void setRepartition(Repartition repartition) {
        this.repartition = repartition;
    }

    /**
     * @param repartitionStrategy repartition strategy to pass to {@code SparkUtils.repartition}
     */
    public void setRepartitionStrategy(RepartitionStrategy repartitionStrategy) {
        this.repartitionStrategy = repartitionStrategy;
    }

    /**
     * @param storageLevel new value of the {@code storageLevel} field
     */
    public void setStorageLevel(StorageLevel storageLevel) {
        this.storageLevel = storageLevel;
    }

    /**
     * @param storageLevel storage level used to persist path RDDs before training;
     *                     {@code null} disables that persistence
     */
    public void setStorageLevelStreams(StorageLevel storageLevel) {
        storageLevelStreams = storageLevel;
    }

    /**
     * @param rDDTrainingApproach new value of the {@code rddTrainingApproach} field
     */
    public void setRddTrainingApproach(RDDTrainingApproach rDDTrainingApproach) {
        rddTrainingApproach = rDDTrainingApproach;
    }

    /**
     * @param str export root directory; {@code null} causes a default derived from
     *            {@code hadoop.tmp.dir} to be used on the next export
     */
    public void setExportDirectory(String str) {
        exportDirectory = str;
    }

    /**
     * @param random new value of the {@code rng} field
     */
    public void setRng(Random random) {
        rng = random;
    }

    /**
     * @param collection training hooks to store (kept by reference, not copied)
     */
    public void setTrainingHookList(Collection<TrainingHook> collection) {
        trainingHookList = collection;
    }

    /**
     * @param i id to record as the last exported RDD; the export logic treats
     *          {@code Integer.MIN_VALUE} as "nothing exported yet"
     */
    public void setLastExportedRDDId(int i) {
        lastExportedRDDId = i;
    }

    /**
     * @param str path to record as the base directory of the last RDD export
     */
    public void setLastRDDExportPath(String str) {
        lastRDDExportPath = str;
    }

    /**
     * @param statsStorageRouter router that should receive listener metadata and updates;
     *                           {@code null} disables forwarding
     */
    public void setStatsStorage(StatsStorageRouter statsStorageRouter) {
        statsStorage = statsStorageRouter;
    }

    /**
     * @return a {@code ParameterAveragingTrainingMaster(name=value, ...)} rendering of all
     *         fields exposed through getters, in declaration order
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("ParameterAveragingTrainingMaster(");
        text.append("saveUpdater=").append(isSaveUpdater());
        text.append(", numWorkers=").append(getNumWorkers());
        text.append(", rddDataSetNumExamples=").append(getRddDataSetNumExamples());
        text.append(", batchSizePerWorker=").append(getBatchSizePerWorker());
        text.append(", averagingFrequency=").append(getAveragingFrequency());
        text.append(", prefetchNumBatches=").append(getPrefetchNumBatches());
        text.append(", collectTrainingStats=").append(isCollectTrainingStats());
        text.append(", stats=").append(getStats());
        text.append(", iterationCount=").append(getIterationCount());
        text.append(", repartition=").append(getRepartition());
        text.append(", repartitionStrategy=").append(getRepartitionStrategy());
        text.append(", storageLevel=").append(getStorageLevel());
        text.append(", storageLevelStreams=").append(getStorageLevelStreams());
        text.append(", rddTrainingApproach=").append(getRddTrainingApproach());
        text.append(", exportDirectory=").append(getExportDirectory());
        text.append(", rng=").append(getRng());
        text.append(", trainingHookList=").append(getTrainingHookList());
        text.append(", lastExportedRDDId=").append(getLastExportedRDDId());
        text.append(", lastRDDExportPath=").append(getLastRDDExportPath());
        text.append(", trainingMasterUID=").append(getTrainingMasterUID());
        text.append(", listeners=").append(getListeners());
        text.append(", statsStorage=").append(getStatsStorage());
        return text.append(")").toString();
    }

    /**
     * Compares the configuration of two training masters.
     *
     * <p>Only these properties take part: saveUpdater, numWorkers, rddDataSetNumExamples,
     * batchSizePerWorker, averagingFrequency, prefetchNumBatches, collectTrainingStats,
     * repartition, repartitionStrategy, storageLevel, storageLevelStreams,
     * rddTrainingApproach, exportDirectory, trainingHookList and statsStorage. Runtime state
     * such as stats, iteration count, rng, export bookkeeping, the UID and the listeners is
     * not compared.
     *
     * @param obj object to compare with
     * @return {@code true} if {@code obj} is a compatible training master with equal
     *         values for all compared properties
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ParameterAveragingTrainingMaster)) {
            return false;
        }
        ParameterAveragingTrainingMaster other = (ParameterAveragingTrainingMaster) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        if (isSaveUpdater() != other.isSaveUpdater()) {
            return false;
        }

        Integer thisWorkers = getNumWorkers();
        Integer otherWorkers = other.getNumWorkers();
        if (thisWorkers == null ? otherWorkers != null : !thisWorkers.equals(otherWorkers)) {
            return false;
        }

        if (getRddDataSetNumExamples() != other.getRddDataSetNumExamples()) {
            return false;
        }
        if (getBatchSizePerWorker() != other.getBatchSizePerWorker()) {
            return false;
        }
        if (getAveragingFrequency() != other.getAveragingFrequency()) {
            return false;
        }
        if (getPrefetchNumBatches() != other.getPrefetchNumBatches()) {
            return false;
        }
        if (isCollectTrainingStats() != other.isCollectTrainingStats()) {
            return false;
        }

        Repartition thisRepartition = getRepartition();
        Repartition otherRepartition = other.getRepartition();
        if (thisRepartition == null ? otherRepartition != null : !thisRepartition.equals(otherRepartition)) {
            return false;
        }

        RepartitionStrategy thisStrategy = getRepartitionStrategy();
        RepartitionStrategy otherStrategy = other.getRepartitionStrategy();
        if (thisStrategy == null ? otherStrategy != null : !thisStrategy.equals(otherStrategy)) {
            return false;
        }

        StorageLevel thisLevel = getStorageLevel();
        StorageLevel otherLevel = other.getStorageLevel();
        if (thisLevel == null ? otherLevel != null : !thisLevel.equals(otherLevel)) {
            return false;
        }

        StorageLevel thisStreamLevel = getStorageLevelStreams();
        StorageLevel otherStreamLevel = other.getStorageLevelStreams();
        if (thisStreamLevel == null ? otherStreamLevel != null : !thisStreamLevel.equals(otherStreamLevel)) {
            return false;
        }

        RDDTrainingApproach thisApproach = getRddTrainingApproach();
        RDDTrainingApproach otherApproach = other.getRddTrainingApproach();
        if (thisApproach == null ? otherApproach != null : !thisApproach.equals(otherApproach)) {
            return false;
        }

        String thisExportDir = getExportDirectory();
        String otherExportDir = other.getExportDirectory();
        if (thisExportDir == null ? otherExportDir != null : !thisExportDir.equals(otherExportDir)) {
            return false;
        }

        Collection<TrainingHook> thisHooks = getTrainingHookList();
        Collection<TrainingHook> otherHooks = other.getTrainingHookList();
        if (thisHooks == null ? otherHooks != null : !thisHooks.equals(otherHooks)) {
            return false;
        }

        StatsStorageRouter thisRouter = getStatsStorage();
        StatsStorageRouter otherRouter = other.getStatsStorage();
        if (thisRouter == null) {
            return otherRouter == null;
        }
        return thisRouter.equals(otherRouter);
    }

    /**
     * Hook used by {@link #equals(Object)} so that the other object can veto equality.
     *
     * @param obj candidate object
     * @return {@code true} if {@code obj} is a {@code ParameterAveragingTrainingMaster}
     */
    protected boolean canEqual(Object obj) {
        boolean sameKind = obj instanceof ParameterAveragingTrainingMaster;
        return sameKind;
    }

    /**
     * Hash code over the same properties that {@link #equals(Object)} compares, combined
     * with multiplier 59; booleans contribute 79/97 and {@code null} references contribute 43.
     *
     * @return the hash code
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + (isSaveUpdater() ? 79 : 97);

        Integer workers = getNumWorkers();
        result = result * prime + (workers == null ? 43 : workers.hashCode());

        result = result * prime + getRddDataSetNumExamples();
        result = result * prime + getBatchSizePerWorker();
        result = result * prime + getAveragingFrequency();
        result = result * prime + getPrefetchNumBatches();
        result = result * prime + (isCollectTrainingStats() ? 79 : 97);

        Repartition repartitionValue = getRepartition();
        result = result * prime + (repartitionValue == null ? 43 : repartitionValue.hashCode());

        RepartitionStrategy strategy = getRepartitionStrategy();
        result = result * prime + (strategy == null ? 43 : strategy.hashCode());

        StorageLevel level = getStorageLevel();
        result = result * prime + (level == null ? 43 : level.hashCode());

        StorageLevel streamLevel = getStorageLevelStreams();
        result = result * prime + (streamLevel == null ? 43 : streamLevel.hashCode());

        RDDTrainingApproach approach = getRddTrainingApproach();
        result = result * prime + (approach == null ? 43 : approach.hashCode());

        String exportDir = getExportDirectory();
        result = result * prime + (exportDir == null ? 43 : exportDir.hashCode());

        Collection<TrainingHook> hooks = getTrainingHookList();
        result = result * prime + (hooks == null ? 43 : hooks.hashCode());

        StatsStorageRouter router = getStatsStorage();
        result = result * prime + (router == null ? 43 : router.hashCode());
        return result;
    }
}
