Machine Learning Notes - First Experience with Deeplearning4j, a Java Deep Learning Framework

1. Overview of Deeplearning4j

        Deeplearning4j is a suite of tools for running deep learning on the JVM. According to the project, it is the only framework that lets you train models from Java while interoperating with the Python ecosystem, through a combination of cpython bindings, model import support, and interoperability with other runtimes such as tensorflow-java and onnxruntime.

        Use cases include importing and retraining models (PyTorch, TensorFlow, Keras) and deploying them in JVM microservice environments, on mobile devices, on IoT devices, and on Apache Spark. It complements a Python environment well: models built in Python can be run, deployed, or packaged for other environments.

        All projects in the DL4J ecosystem support Windows, Linux, and macOS. Hardware support includes CUDA GPUs (10.0, 10.1, 10.2; not available on macOS), x86 CPUs (x86_64, avx2, avx512), ARM CPUs (arm, arm64, armhf), and PowerPC (ppc64le).

2. Deeplearning4j modules

        DL4J: a high-level API for building multilayer networks and computation graphs from a variety of layers, including custom layers. It can import Keras models saved in h5 format, including tf.keras models (as of 1.0.0-M2), and supports distributed training on Apache Spark.
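        As a rough illustration of what a multilayer network computes, here is a plain-Java sketch (no DL4J calls) of one forward pass through the same kind of stack used in section 4: a dense layer with ReLU followed by a softmax output, plus the negative log-likelihood for one example. The class `MlpForwardSketch`, its methods, and the tiny weights are hypothetical; DL4J performs these steps as vectorized ND4J operations rather than with loops like these.

```java
import java.util.Arrays;

// Hypothetical plain-Java sketch of an MLP forward pass; not DL4J code.
final class MlpForwardSketch {

    // Dense layer: out[j] = b[j] + sum_i in[i] * w[i][j]  (input 1 x nIn, weights nIn x nOut).
    static double[] dense(double[] in, double[][] w, double[] b) {
        double[] out = new double[b.length];
        for (int j = 0; j < b.length; j++) {
            double s = b[j];
            for (int i = 0; i < in.length; i++) {
                s += in[i] * w[i][j];
            }
            out[j] = s;
        }
        return out;
    }

    // ReLU activation, applied element-wise.
    static double[] relu(double[] x) {
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            y[i] = Math.max(0.0, x[i]);
        }
        return y;
    }

    // Softmax; subtracting the maximum first avoids overflow in exp().
    static double[] softmax(double[] z) {
        double max = Arrays.stream(z).max().orElse(0.0);
        double[] p = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            p[i] = Math.exp(z[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < z.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    // Negative log-likelihood of the true class under the predicted probabilities.
    static double nll(double[] probs, int trueClass) {
        return -Math.log(probs[trueClass]);
    }

    public static void main(String[] args) {
        // Made-up weights: 2 inputs -> 3 hidden units -> 2 classes (the example uses 20 hidden units).
        double[][] w1 = {{1.0, -1.0, 0.5}, {0.5, 1.0, -0.5}};
        double[] b1 = {0.0, 0.0, 0.0};
        double[][] w2 = {{1.0, -1.0}, {-1.0, 1.0}, {0.5, 0.5}};
        double[] b2 = {0.0, 0.0};

        double[] hidden = relu(dense(new double[]{0.2, 0.4}, w1, b1));
        double[] probs = softmax(dense(hidden, w2, b2));
        System.out.println("probs = " + Arrays.toString(probs) + ", NLL(class 0) = " + nll(probs, 0));
    }
}
```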

        ND4J: a general-purpose linear algebra library with more than 500 mathematical, linear algebra, and deep learning operations. ND4J is built on LibND4J, a highly optimized C++ library that provides CPU (AVX2/512) and GPU (CUDA) support and acceleration through libraries such as OpenBLAS, OneDNN (MKL-DNN), cuDNN, and cuBLAS.

        SameDiff: part of the ND4J library, SameDiff is the project's automatic differentiation and deep learning framework. SameDiff uses a graph-based (define-then-run) approach, similar to TensorFlow's graph mode; eager graph execution (as in TensorFlow 2.x eager mode or PyTorch) is planned. SameDiff can import TensorFlow frozen models in .pb (protobuf) format, and import of ONNX, TensorFlow SavedModel, and Keras models is planned. Deeplearning4j also fully supports SameDiff, which makes it easier to write custom layers and loss functions.
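        The define-then-run idea can be shown without SameDiff itself: first record a graph of operations, then execute it with concrete inputs, then walk it backwards to obtain gradients (reverse-mode automatic differentiation). The following is a minimal sketch under that assumption, restricted to scalar addition and multiplication; `DefineThenRunSketch`, `placeholder`, `run`, and `backward` are made-up names, not SameDiff's API.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical define-then-run graph with reverse-mode autodiff on scalars; not the SameDiff API.
final class DefineThenRunSketch {

    enum Op { PLACEHOLDER, ADD, MUL }

    static final class Node {
        final String name; // only used by placeholders
        final Op op;
        final Node a;      // first input (null for placeholders)
        final Node b;      // second input (null for placeholders)
        double value;
        double grad;

        Node(String name, Op op, Node a, Node b) {
            this.name = name;
            this.op = op;
            this.a = a;
            this.b = b;
        }
    }

    // Nodes in creation order; inputs always exist before they are used, so this is a topological order.
    private final List<Node> nodes = new ArrayList<>();

    // Define phase: these calls only record the graph, nothing is computed yet.
    Node placeholder(String name) { return record(new Node(name, Op.PLACEHOLDER, null, null)); }
    Node add(Node a, Node b) { return record(new Node(null, Op.ADD, a, b)); }
    Node mul(Node a, Node b) { return record(new Node(null, Op.MUL, a, b)); }

    private Node record(Node n) {
        nodes.add(n);
        return n;
    }

    // Run phase: evaluate every node in order with the fed values and clear old gradients.
    void run(Map<String, Double> feed) {
        for (Node n : nodes) {
            switch (n.op) {
                case PLACEHOLDER: {
                    Double v = feed.get(n.name);
                    if (v == null) throw new IllegalArgumentException("no value fed for " + n.name);
                    n.value = v;
                    break;
                }
                case ADD: n.value = n.a.value + n.b.value; break;
                case MUL: n.value = n.a.value * n.b.value; break;
                default: throw new IllegalStateException("unknown op " + n.op);
            }
            n.grad = 0.0;
        }
    }

    // Reverse pass (call after run): seed d(out)/d(out) = 1, then apply the chain rule backwards.
    void backward(Node out) {
        out.grad = 1.0;
        for (int i = nodes.size() - 1; i >= 0; i--) {
            Node n = nodes.get(i);
            if (n.op == Op.ADD) {
                n.a.grad += n.grad;             // d(a+b)/da = 1
                n.b.grad += n.grad;             // d(a+b)/db = 1
            } else if (n.op == Op.MUL) {
                n.a.grad += n.b.value * n.grad; // d(a*b)/da = b
                n.b.grad += n.a.value * n.grad; // d(a*b)/db = a
            }
        }
    }

    public static void main(String[] args) {
        DefineThenRunSketch g = new DefineThenRunSketch();
        Node x = g.placeholder("x");
        Node y = g.placeholder("y");
        Node f = g.add(g.mul(x, y), x); // f = x*y + x, defined once

        Map<String, Double> feed = new HashMap<>();
        feed.put("x", 3.0);
        feed.put("y", 4.0);
        g.run(feed);      // executed with concrete values
        g.backward(f);
        System.out.println("f=" + f.value + " df/dx=" + x.grad + " df/dy=" + y.grad);
    }
}
```

        Because the graph is defined once, the same `f` can be rerun with new values for `x` and `y` without rebuilding it, which is the main practical difference from eager execution.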

        DataVec: ETL for machine learning data in a wide range of formats and files (HDFS, Spark, images, video, audio, CSV, Excel, etc.).
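        For CSV data, section 4 pairs `CSVRecordReader` with `new RecordReaderDataSetIterator(rr, batchSize, 0, 2)`, meaning column 0 holds the class label and there are 2 possible classes. A hedged, stdlib-only approximation of what happens to a single line is shown below: the label column is split off and one-hot encoded, and the remaining columns become features. `CsvToDataSketch` is hypothetical; DataVec itself also handles batching, schemas, and many more formats.

```java
import java.util.Arrays;

// Hypothetical sketch of turning one CSV line into (features, one-hot label); not DataVec code.
final class CsvToDataSketch {

    static final class Example {
        final double[] features;
        final double[] label; // one-hot

        Example(double[] features, double[] label) {
            this.features = features;
            this.label = label;
        }
    }

    // One-hot encodes class index cls over numClasses classes.
    static double[] oneHot(int cls, int numClasses) {
        if (cls < 0 || cls >= numClasses) {
            throw new IllegalArgumentException("label " + cls + " not in [0, " + numClasses + ")");
        }
        double[] v = new double[numClasses];
        v[cls] = 1.0;
        return v;
    }

    // The column at labelIndex is the class label; every other column is a feature.
    static Example parseLine(String line, int labelIndex, int numClasses) {
        String[] cols = line.trim().split(",");
        if (labelIndex < 0 || labelIndex >= cols.length) {
            throw new IllegalArgumentException("labelIndex " + labelIndex + " out of range");
        }
        double[] features = new double[cols.length - 1];
        int label = -1;
        int f = 0;
        for (int i = 0; i < cols.length; i++) {
            double value = Double.parseDouble(cols[i].trim());
            if (i == labelIndex) {
                label = (int) value; // labels written as "1" or "1.0" both map to class 1
            } else {
                features[f++] = value;
            }
        }
        return new Example(features, oneHot(label, numClasses));
    }

    public static void main(String[] args) {
        // A made-up line in "label,x,y" layout (label in column 0, matching labelIndex = 0 in the example).
        Example e = parseLine("1,0.75,0.31", 0, 2);
        System.out.println(Arrays.toString(e.features) + " -> " + Arrays.toString(e.label));
    }
}
```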

        Arbiter: a library for hyperparameter search.
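        Arbiter's job is to explore hyperparameter settings automatically. The simplest strategy of this kind, random search, is sketched here in plain Java: learning rates are sampled log-uniformly and hidden-layer sizes uniformly, each candidate is scored by a caller-supplied function, and the best one is kept. The objective in `main` is a made-up stand-in for "train the model and return a validation loss", and none of these names are part of Arbiter's API.

```java
import java.util.Random;
import java.util.function.ToDoubleBiFunction;

// Hypothetical random-search sketch; not Arbiter's API.
final class RandomSearchSketch {

    // One sampled configuration and its score (lower is better).
    static final class Trial {
        final double learningRate;
        final int hiddenNodes;
        final double score;

        Trial(double learningRate, int hiddenNodes, double score) {
            this.learningRate = learningRate;
            this.hiddenNodes = hiddenNodes;
            this.score = score;
        }

        @Override
        public String toString() {
            return String.format("lr=%.5f hidden=%d score=%.5f", learningRate, hiddenNodes, score);
        }
    }

    // Samples nTrials configurations and returns the one with the lowest objective value (null if nTrials == 0).
    static Trial search(int nTrials, long seed, ToDoubleBiFunction<Double, Integer> objective) {
        Random rng = new Random(seed); // fixed seed for reproducibility
        Trial best = null;
        for (int t = 0; t < nTrials; t++) {
            double lr = Math.pow(10, -4 + 3 * rng.nextDouble()); // log-uniform in [1e-4, 1e-1)
            int hidden = 5 + rng.nextInt(46);                     // uniform integer in [5, 50]
            double score = objective.applyAsDouble(lr, hidden);
            if (best == null || score < best.score) {
                best = new Trial(lr, hidden, score);
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Made-up objective with its minimum at lr = 0.01, hidden = 20; a real one would train and validate a model.
        ToDoubleBiFunction<Double, Integer> objective =
                (lr, hidden) -> Math.pow(Math.log10(lr) + 2, 2) + Math.pow((hidden - 20) / 10.0, 2);
        System.out.println("best: " + search(200, 123L, objective));
    }
}
```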

        LibND4J: the C++ library that underpins everything. For details on how the JVM accesses native arrays and operations, see JavaCPP.

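3. Configuring Deeplearning4j with Maven

        The pom.xml below is a Spring Boot 2.6.4 project that adds deeplearning4j-core and nd4j-native (the CPU backend) at version 1.0.0-M2, plus jfreechart 1.0.13 for the plots.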

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<parent>
		<groupId>org.springframework.boot</groupId>
		<artifactId>spring-boot-starter-parent</artifactId>
		<version>2.6.4</version>
		<relativePath/> <!-- lookup parent from repository -->
	</parent>
	<groupId>com.algorithm</groupId>
	<artifactId>demo</artifactId>
	<version>0.0.1-SNAPSHOT</version>
	<name>demo</name>
	<description>Demo project for Spring Boot</description>
	<properties>
		<dl4j-master.version>1.0.0-M2</dl4j-master.version>
		<java.version>1.8</java.version>
	</properties>
	<dependencies>
		<!-- deeplearning4j-core: contains main functionality and neural networks -->
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-core</artifactId>
			<version>${dl4j-master.version}</version>
		</dependency>
		<dependency>
			<groupId>org.nd4j</groupId>
			<artifactId>nd4j-native</artifactId>
			<version>${dl4j-master.version}</version>
		</dependency>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter</artifactId>
		</dependency>
		<dependency>
			<groupId>jfree</groupId>
			<artifactId>jfreechart</artifactId>
			<version>1.0.13</version>
		</dependency>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-test</artifactId>
			<scope>test</scope>
		</dependency>
	</dependencies>

	<build>
		<plugins>
			<plugin>
				<groupId>org.springframework.boot</groupId>
				<artifactId>spring-boot-maven-plugin</artifactId>
			</plugin>
		</plugins>
	</build>

</project>
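        Once the dependencies above resolve, a quick way to confirm they are really on the runtime classpath is to look up a few of the classes the example in section 4 imports. This sketch uses only `Class.forName` with initialization disabled, so it does not load ND4J's native backend; it only shows that the Java classes are present. `ClasspathCheck` is a hypothetical helper name.

```java
// Hypothetical classpath sanity check using only the JDK.
final class ClasspathCheck {

    // Returns true if the class can be found; initialize=false skips its static initializer.
    static boolean isPresent(String className) {
        try {
            Class.forName(className, false, ClasspathCheck.class.getClassLoader());
            return true;
        } catch (ClassNotFoundException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        // Class names taken from the imports of the example code.
        String[] required = {
                "org.deeplearning4j.nn.multilayer.MultiLayerNetwork",
                "org.nd4j.linalg.factory.Nd4j",
                "org.datavec.api.records.reader.impl.csv.CSVRecordReader",
                "org.jfree.chart.JFreeChart"
        };
        for (String name : required) {
            System.out.println((isPresent(name) ? "OK      " : "MISSING ") + name);
        }
    }
}
```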

4. Linear data classification example

1. Reference code

        LinearDataClassifier.java

package com.algorithm.demo.dl4jexamples;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import com.algorithm.demo.dl4jexamples.utils.DownloaderUtility;
import com.algorithm.demo.dl4jexamples.utils.PlotUtil;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

import java.io.File;
import java.util.concurrent.TimeUnit;

/**
 * "Linear" Data Classification Example
 * 
 * Based on the data from Jason Baldridge:
 * https://github.com/jasonbaldridge/try-tf/tree/master/simdata
 *
 * @author Josh Patterson
 * @author Alex Black (added plots)
 */
@SuppressWarnings("DuplicatedCode")
public class LinearDataClassifier {

    public static boolean visualize = true;
    public static String dataLocalPath;

    public static void main(String[] args) throws Exception {
        int seed = 123;
        double learningRate = 0.01;
        int batchSize = 50;
        int nEpochs = 30;

        int numInputs = 2;
        int numOutputs = 2;
        int numHiddenNodes = 20;

        dataLocalPath = DownloaderUtility.CLASSIFICATIONDATA.Download();

        //Load the training data (column 0 is the label, 2 classes)
        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new File(dataLocalPath, "linear_data_train.csv")));
        DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 2);

        //Load the evaluation data
        RecordReader rrTest = new CSVRecordReader();
        rrTest.initialize(new FileSplit(new File(dataLocalPath, "linear_data_eval.csv")));
        DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);

        //Build the multilayer network configuration
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(learningRate, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .build();

        //Initialize the network
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(10));  //Print score every 10 parameter updates
        //Train for nEpochs epochs
        model.fit(trainIter, nEpochs);

        //Evaluate on the test set
        System.out.println("Evaluate model....");
        Evaluation eval = new Evaluation(numOutputs);
        while (testIter.hasNext()) {
            DataSet t = testIter.next();
            INDArray features = t.getFeatures();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }
        //An alternate way to do the above loop
        //Evaluation evalResults = model.evaluate(testIter);

        //Print the evaluation statistics
        System.out.println(eval.stats());

        System.out.println("\n****************Example finished********************");
        //Training and evaluation finished

        //The code below only plots the data and visualizes the predictions
        generateVisuals(model, trainIter, testIter);
    }

    public static void generateVisuals(MultiLayerNetwork model, DataSetIterator trainIter, DataSetIterator testIter) throws Exception {
        if (visualize) {
            double xMin = 0;
            double xMax = 1.0;
            double yMin = -0.2;
            double yMax = 0.8;
            int nPointsPerAxis = 100;

            //Generate x,y points that span the whole range of features
            INDArray allXYPoints = PlotUtil.generatePointsOnGraph(xMin, xMax, yMin, yMax, nPointsPerAxis);
            //Get train data and plot with predictions
            PlotUtil.plotTrainingData(model, trainIter, allXYPoints, nPointsPerAxis);
            TimeUnit.SECONDS.sleep(3);
            //Get test data, run the test data through the network to generate predictions, and plot those predictions:
            PlotUtil.plotTestData(model, testIter, allXYPoints, nPointsPerAxis);
        }
    }
}
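        The configuration above selects `new Nesterovs(learningRate, 0.9)`, i.e. gradient descent with Nesterov momentum (learning rate 0.01, momentum 0.9). To show what that update does, the sketch below applies the textbook look-ahead form of Nesterov momentum to a one-dimensional quadratic. It is an illustration only: `NesterovSketch` and the objective are hypothetical, and ND4J's updater may use an algebraically equivalent reformulation instead of these exact lines.

```java
import java.util.function.DoubleUnaryOperator;

// Hypothetical sketch of gradient descent with Nesterov momentum; not ND4J's Nesterovs updater code.
final class NesterovSketch {

    // Minimizes a 1-D function given its gradient, using the look-ahead form of Nesterov momentum:
    //   v <- mu * v - lr * grad(w + mu * v)
    //   w <- w + v
    // With mu = 0 this reduces to plain gradient descent.
    static double minimize(DoubleUnaryOperator grad, double w0, double lr, double mu, int steps) {
        double w = w0;
        double v = 0.0;
        for (int t = 0; t < steps; t++) {
            double lookAhead = w + mu * v;                   // where momentum is about to carry w
            v = mu * v - lr * grad.applyAsDouble(lookAhead); // gradient evaluated at the look-ahead point
            w = w + v;
        }
        return w;
    }

    public static void main(String[] args) {
        // f(w) = (w - 3)^2 has gradient 2 * (w - 3) and its minimum at w = 3.
        DoubleUnaryOperator grad = w -> 2.0 * (w - 3.0);
        // Same hyperparameters as the example: learning rate 0.01, momentum 0.9.
        System.out.println("Nesterov, 200 steps: " + minimize(grad, 0.0, 0.01, 0.9, 200));
        System.out.println("Plain GD, 200 steps: " + minimize(grad, 0.0, 0.01, 0.0, 200));
    }
}
```

        Setting `mu` to 0 turns this into plain gradient descent, which makes a convenient baseline: on this quadratic, with the same learning rate and number of steps, the momentum version ends up much closer to w = 3.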

        PlotUtil.java, the plotting utility

package com.algorithm.demo.dl4jexamples.utils;

import org.deeplearning4j.datasets.iterator.utilty.ListDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.AxisLocation;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.block.BlockBorder;
import org.jfree.chart.plot.DatasetRenderingOrder;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.GrayPaintScale;
import org.jfree.chart.renderer.PaintScale;
import org.jfree.chart.renderer.xy.XYBlockRenderer;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.chart.title.PaintScaleLegend;
import org.jfree.data.xy.*;
import org.jfree.ui.RectangleEdge;
import org.jfree.ui.RectangleInsets;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

import javax.swing.*;
import java.awt.*;
import java.util.ArrayList;
import java.util.List;

/**
 * Simple plotting methods for the MLPClassifier quickstartexamples
 *
 * @author Alex Black
 */
public class PlotUtil {

    /**
     * Plot the training data. Assume 2d input, classification output
     *
     * @param model         Model to use to get predictions
     * @param trainIter     DataSet Iterator
     * @param backgroundIn  sets of x,y points in input space, plotted in the background
     * @param nDivisions    Number of points (per axis, for the backgroundIn/backgroundOut arrays)
     */
    public static void plotTrainingData(MultiLayerNetwork model, DataSetIterator trainIter, INDArray backgroundIn, int nDivisions) {
        double[] mins = backgroundIn.min(0).data().asDouble();
        double[] maxs = backgroundIn.max(0).data().asDouble();

        DataSet ds = allBatches(trainIter);
        INDArray backgroundOut = model.output(backgroundIn);

        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTrain(ds.getFeatures(), ds.getLabels())));

        JFrame f = new JFrame();
        f.add(panel);
        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        f.pack();
        f.setTitle("Training Data");

        f.setVisible(true);
        f.setLocation(0, 0);
    }

    /**
     * Plot the test data. Assume 2d input, classification output
     *
     * @param model         Model to use to get predictions
     * @param testIter      Test Iterator
     * @param backgroundIn  sets of x,y points in input space, plotted in the background
     * @param nDivisions    Number of points (per axis, for the backgroundIn/backgroundOut arrays)
     */
    public static void plotTestData(MultiLayerNetwork model, DataSetIterator testIter, INDArray backgroundIn, int nDivisions) {

        double[] mins = backgroundIn.min(0).data().asDouble();
        double[] maxs = backgroundIn.max(0).data().asDouble();

        INDArray backgroundOut = model.output(backgroundIn);
        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
        DataSet ds = allBatches(testIter);
        INDArray predicted = model.output(ds.getFeatures());
        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTest(ds.getFeatures(), ds.getLabels(), predicted)));

        JFrame f = new JFrame();
        f.add(panel);
        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        f.pack();
        f.setTitle("Test Data");

        f.setVisible(true);
        f.setLocationRelativeTo(null);
        //f.setLocation(100,100);

    }


    /**
     * Create data for the background data set
     */
    private static XYZDataset createBackgroundData(INDArray backgroundIn, INDArray backgroundOut) {
        int nRows = backgroundIn.rows();
        double[] xValues = new double[nRows];
        double[] yValues = new double[nRows];
        double[] zValues = new double[nRows];
        for (int i = 0; i < nRows; i++) {
            xValues[i] = backgroundIn.getDouble(i, 0);
            yValues[i] = backgroundIn.getDouble(i, 1);
            zValues[i] = backgroundOut.getDouble(i, 0);

        }

        DefaultXYZDataset dataset = new DefaultXYZDataset();
        dataset.addSeries("Series 1",
                new double[][]{xValues, yValues, zValues});
        return dataset;
    }

    //Training data
    private static XYDataset createDataSetTrain(INDArray features, INDArray labels) {
        int nRows = features.rows();

        int nClasses = 2; // Binary classification: two output classes (softmax output layer).

        XYSeries[] series = new XYSeries[nClasses];
        for (int i = 0; i < series.length; i++) series[i] = new XYSeries("Class " + i);
        INDArray argMax = Nd4j.getExecutioner().exec(new ArgMax(new INDArray[]{labels},false,new int[]{1}))[0];
        for (int i = 0; i < nRows; i++) {
            int classIdx = (int) argMax.getDouble(i);
            series[classIdx].add(features.getDouble(i, 0), features.getDouble(i, 1));
        }

        XYSeriesCollection c = new XYSeriesCollection();
        for (XYSeries s : series) c.addSeries(s);
        return c;
    }

    //Test data
    private static XYDataset createDataSetTest(INDArray features, INDArray labels, INDArray predicted) {
        int nRows = features.rows();

        int nClasses = 2; // Binary classification: two output classes (softmax output layer).

        XYSeries[] series = new XYSeries[nClasses * nClasses];
        int[] series_index = new int[]{0, 3, 2, 1}; //little hack to make the charts look consistent.
        for (int i = 0; i < nClasses * nClasses; i++) {
            int trueClass = i / nClasses;
            int predClass = i % nClasses;
            String label = "actual=" + trueClass + ", pred=" + predClass;
            series[series_index[i]] = new XYSeries(label);
        }
        INDArray actualIdx = labels.argMax(1);
        INDArray predictedIdx = predicted.argMax(1);
        for (int i = 0; i < nRows; i++) {
            int classIdx = actualIdx.getInt(i);
            int predIdx = predictedIdx.getInt(i);
            int idx = series_index[classIdx * nClasses + predIdx];
            series[idx].add(features.getDouble(i, 0), features.getDouble(i, 1));
        }

        XYSeriesCollection c = new XYSeriesCollection();
        for (XYSeries s : series) c.addSeries(s);
        return c;
    }

    private static JFreeChart createChart(XYZDataset dataset, double[] mins, double[] maxs, int nPoints, XYDataset xyData) {
        NumberAxis xAxis = new NumberAxis("X");
        xAxis.setRange(mins[0], maxs[0]);


        NumberAxis yAxis = new NumberAxis("Y");
        yAxis.setRange(mins[1], maxs[1]);

        XYBlockRenderer renderer = new XYBlockRenderer();
        renderer.setBlockWidth((maxs[0] - mins[0]) / (nPoints - 1));
        renderer.setBlockHeight((maxs[1] - mins[1]) / (nPoints - 1));
        PaintScale scale = new GrayPaintScale(0, 1.0);
        renderer.setPaintScale(scale);
        XYPlot plot = new XYPlot(dataset, xAxis, yAxis, renderer);
        plot.setBackgroundPaint(Color.lightGray);
        plot.setDomainGridlinesVisible(false);
        plot.setRangeGridlinesVisible(false);
        plot.setAxisOffset(new RectangleInsets(5, 5, 5, 5));
        JFreeChart chart = new JFreeChart("", plot);
        chart.getXYPlot().getRenderer().setSeriesVisibleInLegend(0, false);


        NumberAxis scaleAxis = new NumberAxis("Probability (class 1)");
        scaleAxis.setAxisLinePaint(Color.white);
        scaleAxis.setTickMarkPaint(Color.white);
        scaleAxis.setTickLabelFont(new Font("Dialog", Font.PLAIN, 7));
        PaintScaleLegend legend = new PaintScaleLegend(new GrayPaintScale(),
                scaleAxis);
        legend.setStripOutlineVisible(false);
        legend.setSubdivisionCount(20);
        legend.setAxisLocation(AxisLocation.BOTTOM_OR_LEFT);
        legend.setAxisOffset(5.0);
        legend.setMargin(new RectangleInsets(5, 5, 5, 5));
        legend.setFrame(new BlockBorder(Color.red));
        legend.setPadding(new RectangleInsets(10, 10, 10, 10));
        legend.setStripWidth(10);
        legend.setPosition(RectangleEdge.LEFT);
        chart.addSubtitle(legend);

        ChartUtilities.applyCurrentTheme(chart);

        plot.setDataset(1, xyData);
        XYLineAndShapeRenderer renderer2 = new XYLineAndShapeRenderer();
        renderer2.setBaseLinesVisible(false);
        plot.setRenderer(1, renderer2);

        plot.setDatasetRenderingOrder(DatasetRenderingOrder.FORWARD);

        return chart;
    }

    public static INDArray generatePointsOnGraph(double xMin, double xMax, double yMin, double yMax, int nPointsPerAxis) {
        //generate all the x,y points
        double[][] evalPoints = new double[nPointsPerAxis * nPointsPerAxis][2];
        int count = 0;
        for (int i = 0; i < nPointsPerAxis; i++) {
            for (int j = 0; j < nPointsPerAxis; j++) {
                double x = i * (xMax - xMin) / (nPointsPerAxis - 1) + xMin;
                double y = j * (yMax - yMin) / (nPointsPerAxis - 1) + yMin;

                evalPoints[count][0] = x;
                evalPoints[count][1] = y;

                count++;
            }
        }

        return Nd4j.create(evalPoints);
    }

    /**
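     * Collects all the data from the iterator and returns it as one minibatch. Only suitable for small datasets like the ones used here.
     * @param iter iterator to drain; it is reset before and after collecting
     * @return a single DataSet containing every example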
     */
    private static DataSet allBatches(DataSetIterator iter) {

        List<DataSet> fullSet = new ArrayList<>();
        iter.reset();
        while (iter.hasNext()) {
            List<DataSet> miniBatchList = iter.next().asList();
            fullSet.addAll(miniBatchList);
        }
        iter.reset();
        return new ListDataSetIterator<>(fullSet,fullSet.size()).next();
    }

}
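
        `PlotUtil` rests on two small ideas: `generatePointsOnGraph` lays an evenly spaced n x n grid over the feature range (the background of the chart), and `argMax(1)` turns each row of class scores into a class index. A stdlib-only sketch of both follows; `GridSketch`, `grid`, and `argMaxRows` are hypothetical names, and the grid keeps the same point order as `generatePointsOnGraph` (x in the outer loop, y in the inner loop).

```java
// Hypothetical stdlib-only sketch of the grid and argmax steps used for plotting; not PlotUtil itself.
final class GridSketch {

    // n x n evenly spaced points; row k = i * n + j holds (x_i, y_j).
    static double[][] grid(double xMin, double xMax, double yMin, double yMax, int n) {
        double[][] pts = new double[n * n][2];
        int count = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                pts[count][0] = xMin + i * (xMax - xMin) / (n - 1);
                pts[count][1] = yMin + j * (yMax - yMin) / (n - 1);
                count++;
            }
        }
        return pts;
    }

    // Index of the largest value in each row; ties go to the lower index.
    static int[] argMaxRows(double[][] scores) {
        int[] idx = new int[scores.length];
        for (int r = 0; r < scores.length; r++) {
            int best = 0;
            for (int c = 1; c < scores[r].length; c++) {
                if (scores[r][c] > scores[r][best]) {
                    best = c;
                }
            }
            idx[r] = best;
        }
        return idx;
    }

    public static void main(String[] args) {
        // Same ranges and resolution as generateVisuals in the example.
        double[][] pts = grid(0.0, 1.0, -0.2, 0.8, 100);
        double[] last = pts[pts.length - 1];
        System.out.println(pts.length + " points, last = (" + last[0] + ", " + last[1] + ")");
    }
}
```
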

        DownloaderUtility.java, the download utility class

package com.algorithm.demo.dl4jexamples.utils;

import org.apache.commons.io.FilenameUtils;
import org.nd4j.common.resources.Downloader;

import java.io.File;
import java.net.URL;

/**
 * Given a base url and a zipped file name downloads contents to a specified directory under ~/dl4j-examples-data
 * Will check md5 sum of downloaded file
 * <p>
 *
 * Sample Usage with an instantiation DATAEXAMPLE(baseurl,"DataExamples.zip","data-dir",md5,size):
 *
 * DATAEXAMPLE.Download() & DATAEXAMPLE.Download(true)
 * Will download DataExamples.zip from baseurl/DataExamples.zip to a temp directory,
 * Unzip it to ~/dl4j-examples-data/data-dir
 * Return the string "~/dl4j-examples-data/data-dir/DataExamples"
 *
 * DATAEXAMPLE.Download(false)
 * will perform the same download and unzip as above
 * But returns the string "~/dl4j-examples-data/data-dir" instead
 *
 *
 * @author susaneraly
 */
public enum DownloaderUtility {

    IRISDATA("IrisData.zip", "datavec-examples", "bb49e38bb91089634d7ef37ad8e430b8", "1KB"),
    ANIMALS("animals.zip", "dl4j-examples", "1976a1f2b61191d2906e4f615246d63e", "820KB"),
    ANOMALYSEQUENCEDATA("anomalysequencedata.zip", "dl4j-examples", "51bb7c50e265edec3a241a2d7cce0e73", "3MB"),
    CAPTCHAIMAGE("captchaImage.zip", "dl4j-examples", "1d159c9587fdbb1cbfd66f0d62380e61", "42MB"),
    CLASSIFICATIONDATA("classification.zip", "dl4j-examples", "dba31e5838fe15993579edbf1c60c355", "77KB"),
    DATAEXAMPLES("DataExamples.zip", "dl4j-examples", "e4de9c6f19aaae21fed45bfe2a730cbb", "2MB"),
    LOTTERYDATA("lottery.zip", "dl4j-examples", "1e54ac1210e39c948aa55417efee193a", "2MB"),
    NEWSDATA("NewsData.zip", "dl4j-examples", "0d08e902faabe6b8bfe5ecdd78af9f64", "21MB"),
    NLPDATA("nlp.zip", "dl4j-examples", "1ac7cd7ca08f13402f0e3b83e20c0512", "91MB"),
    PREDICTGENDERDATA("PredictGender.zip", "dl4j-examples", "42a3fec42afa798217e0b8687667257e", "3MB"),
    STYLETRANSFER("styletransfer.zip", "dl4j-examples", "b2b90834d667679d7ee3dfb1f40abe94", "3MB"),
    VIDEOEXAMPLE("video.zip","dl4j-examples", "56274eb6329a848dce3e20631abc6752", "8.5MB");

    private final String BASE_URL;
    private final String DATA_FOLDER;
    private final String ZIP_FILE;
    private final String MD5;
    private final String DATA_SIZE;
    private static final String AZURE_BLOB_URL = "https://dl4jdata.blob.core.windows.net/dl4j-examples";

    /**
     * For use with resources uploaded to Azure blob storage.
     *
     * @param zipFile    Name of zipfile. Should be a zip of a single directory with the same name
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String zipFile, String dataFolder, String md5, String dataSize) {
        this(AZURE_BLOB_URL + "/" + dataFolder, zipFile, dataFolder, md5, dataSize);
    }

    /**
     * Downloads a zip file from a base url to a specified directory under the user's home directory
     *
     * @param baseURL    URL of file
     * @param zipFile    Name of zipfile to download from baseURL i.e baseURL+"/"+zipFile gives full URL
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String baseURL, String zipFile, String dataFolder, String md5, String dataSize) {
        BASE_URL = baseURL;
        DATA_FOLDER = dataFolder;
        ZIP_FILE = zipFile;
        MD5 = md5;
        DATA_SIZE = dataSize;
    }

    public String Download() throws Exception {
        return Download(true);
    }

    public String Download(boolean returnSubFolder) throws Exception {
        String dataURL = BASE_URL + "/" + ZIP_FILE;
        String downloadPath = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), ZIP_FILE);
        String extractDir = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/" + DATA_FOLDER);
        if (!new File(extractDir).exists())
            new File(extractDir).mkdirs();
        String dataPathLocal = extractDir;
        if (returnSubFolder) {
            String resourceName = ZIP_FILE.substring(0, ZIP_FILE.lastIndexOf(".zip"));
            dataPathLocal = FilenameUtils.concat(extractDir, resourceName);
        }
        int downloadRetries = 10;
        if (!new File(dataPathLocal).exists() || new File(dataPathLocal).list().length == 0) {
            System.out.println("_______________________________________________________________________");
            System.out.println("Downloading data (" + DATA_SIZE + ") and extracting to \n\t" + dataPathLocal);
            System.out.println("_______________________________________________________________________");
            Downloader.downloadAndExtract("files",
                    new URL(dataURL),
                    new File(downloadPath),
                    new File(extractDir),
                    MD5,
                    downloadRetries);
        } else {
            System.out.println("_______________________________________________________________________");
            System.out.println("Example data present in \n\t" + dataPathLocal);
            System.out.println("_______________________________________________________________________");
        }
        return dataPathLocal;
    }
}
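        `DownloaderUtility` hands an expected MD5 string to `Downloader.downloadAndExtract`, and per its class comment the MD5 sum of the downloaded file is checked. How such a check can be done with the JDK alone is sketched below using `java.security.MessageDigest`; `Md5Check` and `md5Hex` are hypothetical names, and this is not ND4J's `Downloader` code.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

// Hypothetical MD5 verification helper using only the JDK.
final class Md5Check {

    private static MessageDigest newMd5() {
        try {
            return MessageDigest.getInstance("MD5");
        } catch (NoSuchAlgorithmException e) {
            // Every Java SE implementation is required to support MD5, so this should not happen.
            throw new IllegalStateException(e);
        }
    }

    // Lower-case hex MD5 of a byte array.
    static String md5Hex(byte[] data) {
        return toHex(newMd5().digest(data));
    }

    // Streams a file through the digest so large archives do not have to fit in memory.
    static String md5Hex(Path file) throws IOException {
        MessageDigest md = newMd5();
        try (InputStream in = Files.newInputStream(file)) {
            byte[] buf = new byte[8192];
            int n;
            while ((n = in.read(buf)) != -1) {
                md.update(buf, 0, n);
            }
        }
        return toHex(md.digest());
    }

    private static String toHex(byte[] digest) {
        StringBuilder sb = new StringBuilder(digest.length * 2);
        for (byte b : digest) {
            sb.append(String.format("%02x", b & 0xff));
        }
        return sb.toString();
    }

    public static void main(String[] args) throws Exception {
        if (args.length == 2) {
            // Usage: java Md5Check <file> <expectedMd5>
            boolean ok = md5Hex(Paths.get(args[0])).equalsIgnoreCase(args[1]);
            System.out.println(ok ? "checksum OK" : "checksum MISMATCH");
        } else {
            System.out.println(md5Hex("abc".getBytes(StandardCharsets.UTF_8)));
        }
    }
}
```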

2. Results

         For anyone already familiar with Java, it is still very comfortable to use.

Origin blog.csdn.net/bashendixie5/article/details/123600031