package org.deeplearning4j.parallelism.main;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import java.io.File;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.util.ModelGuesser;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Command-line entry point that loads a serialized model, trains it with a {@link ParallelWrapper}
 * on data produced by a user-supplied iterator provider factory, and writes the trained model to
 * {@code --modelOutputPath}.
 *
 * <p>At least one of {@code --dataSetIteratorFactoryClazz} or
 * {@code --multiDataSetIteratorFactoryClazz} must be supplied; when both are given, the
 * {@link DataSetIterator} factory is used.
 */
public class ParallelWrapperMain {

    @Parameter(names = {"--modelPath"}, description = "Path to the model", arity = 1, required = true)
    private String modelPath = null;

    @Parameter(names = {"--workers"}, description = "Number of workers", arity = 1)
    private int workers = 2;

    @Parameter(names = {"--prefetchSize"}, description = "The number of datasets to prefetch", arity = 1)
    private int prefetchSize = 16;

    @Parameter(names = {"--averagingFrequency"}, description = "The frequency for averaging parameters", arity = 1)
    private int averagingFrequency = 1;

    @Parameter(names = {"--reportScore"}, description = "Whether to report the score after averaging", arity = 1)
    private boolean reportScore = false;

    @Parameter(names = {"--averageUpdaters"}, description = "Whether to average updaters", arity = 1)
    private boolean averageUpdaters = true;

    @Parameter(names = {"--legacyAveraging"}, description = "Whether to use legacy averaging", arity = 1)
    private boolean legacyAveraging = true;

    @Parameter(names = {"--dataSetIteratorFactoryClazz"}, description = "The fully qualified class name of the DataSetIteratorProviderFactory to use.", arity = 1)
    private String dataSetIteratorFactoryClazz = null;

    @Parameter(names = {"--multiDataSetIteratorFactoryClazz"}, description = "The fully qualified class name of the MultiDataSetProviderFactory to use.", arity = 1)
    private String multiDataSetIteratorFactoryClazz = null;

    @Parameter(names = {"--modelOutputPath"}, description = "Path to write the trained model to", arity = 1, required = true)
    private String modelOutputPath = null;

    @Parameter(names = {"--uiUrl"}, description = "The host:port of the ui to use (optional)", arity = 1)
    private String uiUrl = null;

    public static void main(String[] strArr) throws Exception {
        new ParallelWrapperMain().runMain(strArr);
    }

    /**
     * Parses {@code strArr}, trains the model described by the parsed options and saves it.
     *
     * <p>On a command-line parse error the usage text is printed and the JVM exits with status 1.
     *
     * @param strArr command-line arguments
     * @throws IllegalStateException if neither iterator factory class was supplied
     * @throws IllegalArgumentException if the factory class does not implement the expected interface
     * @throws Exception if loading the model, instantiating the factory, training or saving fails
     */
    public void runMain(String... strArr) throws Exception {
        JCommander jCommander = new JCommander(this);
        try {
            jCommander.parse(strArr);
        } catch (ParameterException e) {
            System.err.println(e.getMessage());
            jCommander.usage();
            try {
                // NOTE(review): presumably gives the usage output time to reach the console before exit.
                Thread.sleep(500L);
            } catch (InterruptedException ignored) {
                // We exit right after this anyway; just preserve the interrupt status.
                Thread.currentThread().interrupt();
            }
            System.exit(1);
        }

        // Validate before the (potentially expensive) model load so misconfiguration fails fast.
        if (this.dataSetIteratorFactoryClazz == null && this.multiDataSetIteratorFactoryClazz == null) {
            throw new IllegalStateException(
                    "Please provide a datasetiterator or multi datasetiterator class "
                            + "(--dataSetIteratorFactoryClazz or --multiDataSetIteratorFactoryClazz)");
        }

        Model model = ModelGuesser.loadModelGuess(this.modelPath);
        ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
                .prefetchBuffer(this.prefetchSize)
                .workers(this.workers)
                .averagingFrequency(this.averagingFrequency)
                .averageUpdaters(this.averageUpdaters)
                .reportScoreAfterAveraging(this.reportScore)
                .useLegacyAveraging(this.legacyAveraging)
                .build();

        // The DataSetIterator factory takes precedence when both are configured.
        if (this.dataSetIteratorFactoryClazz != null) {
            DataSetIterator iterator =
                    instantiate(this.dataSetIteratorFactoryClazz, DataSetIteratorProviderFactory.class).create();
            attachUiListenerIfConfigured(wrapper);
            wrapper.fit(iterator);
        } else {
            MultiDataSetIterator iterator =
                    instantiate(this.multiDataSetIteratorFactoryClazz, MultiDataSetProviderFactory.class).create();
            attachUiListenerIfConfigured(wrapper);
            wrapper.fit(iterator);
        }
        ModelSerializer.writeModel(model, new File(this.modelOutputPath), true);
    }

    /**
     * Routes training stats to the remote UI at {@code http://<uiUrl>} when {@code --uiUrl} was
     * given; does nothing otherwise.
     */
    private void attachUiListenerIfConfigured(ParallelWrapper wrapper) {
        if (this.uiUrl == null) {
            return;
        }
        StatsStorageRouter remoteRouter = new RemoteUIStatsStorageRouter("http://" + this.uiUrl);
        // The listener is created without its own router and handed to the wrapper together with
        // remoteRouter. NOTE(review): assumes ParallelWrapper wires that router into the per-worker
        // listeners; confirm against ParallelWrapper.setListeners.
        wrapper.setListeners(remoteRouter, new StatsListener((StatsStorageRouter) null));
    }

    /**
     * Loads {@code className} reflectively and creates an instance through its no-arg constructor.
     *
     * @param className fully qualified name of the class to instantiate
     * @param expectedType interface or supertype the class must implement
     * @return the new instance, typed as {@code expectedType}
     * @throws IllegalArgumentException if the class is not a subtype of {@code expectedType}
     * @throws ReflectiveOperationException if the class cannot be found or instantiated
     */
    private static <T> T instantiate(String className, Class<T> expectedType)
            throws ReflectiveOperationException {
        Class<?> clazz = Class.forName(className);
        if (!expectedType.isAssignableFrom(clazz)) {
            throw new IllegalArgumentException(
                    "Class " + className + " does not implement " + expectedType.getName());
        }
        // Replaces the deprecated Class.newInstance(), which could rethrow the constructor's
        // checked exceptions without declaring them.
        return expectedType.cast(clazz.getDeclaredConstructor().newInstance());
    }
}
