Machine Learning Notes: A First Look at Deeplearning4j, a Java Deep Learning Framework

1. Overview of Deeplearning4j

        Deeplearning4j is a suite of tools for running deep learning on the JVM. According to the project, it is the only framework that lets you train models from Java while interoperating with the Python ecosystem through a mix of CPython bindings, model import support, and interop with other runtimes such as tensorflow-java and onnxruntime.

        Use cases include importing and retraining models (PyTorch, TensorFlow, Keras) and deploying them in JVM microservice environments, on mobile devices, on IoT hardware, and on Apache Spark. It complements a Python environment well: models built in Python can be run on the JVM, deployed to other environments, or packaged for them.

        All projects in the DL4J ecosystem support Windows, Linux, and macOS. Hardware support includes CUDA GPUs (CUDA 10.0, 10.1, 10.2; not on macOS), x86 CPUs (x86_64, avx2, avx512), ARM CPUs (arm, arm64, armhf), and PowerPC (ppc64le).

2. Deeplearning4j module composition

        DL4J: A high-level API for building multi-layer networks and computation graphs from a variety of layers, including custom layers. It supports importing Keras models from h5 files, including tf.keras models (as of 1.0.0-M2), and distributed training on Apache Spark.

        ND4J: A general-purpose linear algebra library with over 500 math, linear algebra, and deep learning operations. ND4J is built on the highly optimized C++ library LibND4J, which provides CPU (AVX2/512) and GPU (CUDA) support and acceleration through libraries such as OpenBLAS, OneDNN (MKL-DNN), cuDNN, and cuBLAS.

        SameDiff: Part of the ND4J library, SameDiff is the project's automatic differentiation and deep learning framework. It uses a graph-based (define then run) approach, similar to TensorFlow graph mode. Eager execution (as in TensorFlow 2.x eager mode or PyTorch) is planned. SameDiff supports importing TensorFlow frozen-model .pb (protobuf) files; import of ONNX, TensorFlow SavedModel, and Keras models is planned. Deeplearning4j also has full SameDiff support, which makes it easy to write custom layers and loss functions.

        DataVec: ETL for machine learning data in various formats and files (HDFS, Spark, images, video, audio, CSV, Excel, etc.)

        Arbiter: A hyperparameter search library.

        LibND4J: The C++ library that underpins everything. For more information on how the JVM accesses native arrays and operations, see JavaCPP.
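
        To make the "define then run" idea mentioned for SameDiff more concrete, here is a minimal plain-Java sketch. It does not use the SameDiff API at all: the class DefineThenRunSketch, the Node interface, and the helper methods variable, add, and mul are hypothetical names invented for this illustration, and the derivative is computed with a simple recursive sum and product rule rather than whatever SameDiff does internally. The point is only the two-phase workflow: first describe the computation as a graph, then feed in values and evaluate it.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical mini expression graph. NOT the SameDiff API, only the define-then-run idea. */
final class DefineThenRunSketch {

    /** A graph node: it holds no values until a feed map is supplied at run time. */
    interface Node {
        double eval(Map<String, Double> feed);

        /** Partial derivative with respect to the named variable, evaluated at feed. */
        double grad(String wrt, Map<String, Double> feed);
    }

    /** A named placeholder whose value comes from the feed map. */
    static Node variable(final String name) {
        return new Node() {
            public double eval(Map<String, Double> feed) {
                return feed.get(name);
            }
            public double grad(String wrt, Map<String, Double> feed) {
                return name.equals(wrt) ? 1.0 : 0.0;
            }
        };
    }

    /** Sum node: d(a + b) = da + db. */
    static Node add(final Node a, final Node b) {
        return new Node() {
            public double eval(Map<String, Double> feed) {
                return a.eval(feed) + b.eval(feed);
            }
            public double grad(String wrt, Map<String, Double> feed) {
                return a.grad(wrt, feed) + b.grad(wrt, feed);
            }
        };
    }

    /** Product node: d(a * b) = da * b + a * db. */
    static Node mul(final Node a, final Node b) {
        return new Node() {
            public double eval(Map<String, Double> feed) {
                return a.eval(feed) * b.eval(feed);
            }
            public double grad(String wrt, Map<String, Double> feed) {
                return a.grad(wrt, feed) * b.eval(feed) + a.eval(feed) * b.grad(wrt, feed);
            }
        };
    }

    public static void main(String[] args) {
        // Phase 1 (define): build the graph y = w * x + b. Nothing is computed yet.
        Node y = add(mul(variable("w"), variable("x")), variable("b"));

        // Phase 2 (run): supply concrete values and evaluate.
        Map<String, Double> feed = new HashMap<>();
        feed.put("w", 2.0);
        feed.put("x", 3.0);
        feed.put("b", 1.0);

        System.out.println("y     = " + y.eval(feed));      // 7.0
        System.out.println("dy/dw = " + y.grad("w", feed)); // 3.0
    }
}
```

        The same graph object can be evaluated again with a different feed map, which is what separates this style from eager execution, where each operation runs as soon as it is written.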

3. Configure Deeplearning4j in Maven

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<parent>
		<groupId>org.springframework.boot</groupId>
		<artifactId>spring-boot-starter-parent</artifactId>
		<version>2.6.4</version>
		<relativePath/> <!-- lookup parent from repository -->
	</parent>
	<groupId>com.algorithm</groupId>
	<artifactId>demo</artifactId>
	<version>0.0.1-SNAPSHOT</version>
	<name>demo</name>
	<description>Demo project for Spring Boot</description>
	<properties>
		<dl4j-master.version>1.0.0-M2</dl4j-master.version>
		<java.version>1.8</java.version>
	</properties>
	<dependencies>
		<!-- deeplearning4j-core: contains main functionality and neural networks -->
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-core</artifactId>
			<version>${dl4j-master.version}</version>
		</dependency>
		<dependency>
			<groupId>org.nd4j</groupId>
			<artifactId>nd4j-native</artifactId>
			<version>${dl4j-master.version}</version>
		</dependency>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter</artifactId>
		</dependency>
		<dependency>
			<groupId>jfree</groupId>
			<artifactId>jfreechart</artifactId>
			<version>1.0.13</version>
		</dependency>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-test</artifactId>
			<scope>test</scope>
		</dependency>
	</dependencies>

	<build>
		<plugins>
			<plugin>
				<groupId>org.springframework.boot</groupId>
				<artifactId>spring-boot-maven-plugin</artifactId>
			</plugin>
		</plugins>
	</build>

</project>

4. Linear data classification example

1. Reference code

        LinearDataClassifier.java

package com.algorithm.demo.dl4jexamples;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import com.algorithm.demo.dl4jexamples.utils.DownloaderUtility;
import com.algorithm.demo.dl4jexamples.utils.PlotUtil;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

import java.io.File;
import java.util.concurrent.TimeUnit;

/**
 * "Linear" Data Classification Example
 * 
 * Based on the data from Jason Baldridge:
 * https://github.com/jasonbaldridge/try-tf/tree/master/simdata
 *
 * @author Josh Patterson
 * @author Alex Black (added plots)
 */
@SuppressWarnings("DuplicatedCode")
public class LinearDataClassifier {

    public static boolean visualize = true;
    public static String dataLocalPath;

    public static void main(String[] args) throws Exception {
        int seed = 123;
        double learningRate = 0.01;
        int batchSize = 50;
        int nEpochs = 30;

        int numInputs = 2;
        int numOutputs = 2;
        int numHiddenNodes = 20;

        dataLocalPath = DownloaderUtility.CLASSIFICATIONDATA.Download();

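        //Load the training data
        //(labelIndex = 0: the first CSV column holds the class label; 2 = number of possible labels)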
        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new File(dataLocalPath, "linear_data_train.csv")));
        DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 2);

        //Load the evaluation data
        RecordReader rrTest = new CSVRecordReader();
        rrTest.initialize(new FileSplit(new File(dataLocalPath, "linear_data_eval.csv")));
        DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);

        //Build the multi-layer network configuration
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(learningRate, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .build();

        //Initialize the network
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(10));  //Print score every 10 parameter updates
        //Train the model
        model.fit(trainIter, nEpochs);

        //Evaluate the model
        System.out.println("Evaluate model....");
        Evaluation eval = new Evaluation(numOutputs);
        while (testIter.hasNext()) {
            DataSet t = testIter.next();
            INDArray features = t.getFeatures();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }
        //An alternate way to do the above loop
        //Evaluation evalResults = model.evaluate(testIter);

        //Print the evaluation statistics
        System.out.println(eval.stats());

        System.out.println("\n****************Example finished********************");
        //Training and evaluation are complete

        //The code below is only for plotting the data and visualizing the predictions
        generateVisuals(model, trainIter, testIter);
    }

    public static void generateVisuals(MultiLayerNetwork model, DataSetIterator trainIter, DataSetIterator testIter) throws Exception {
        if (visualize) {
            double xMin = 0;
            double xMax = 1.0;
            double yMin = -0.2;
            double yMax = 0.8;
            int nPointsPerAxis = 100;

            //Generate x,y points that span the whole range of features
            INDArray allXYPoints = PlotUtil.generatePointsOnGraph(xMin, xMax, yMin, yMax, nPointsPerAxis);
            //Get train data and plot with predictions
            PlotUtil.plotTrainingData(model, trainIter, allXYPoints, nPointsPerAxis);
            TimeUnit.SECONDS.sleep(3);
            //Get test data, run the test data through the network to generate predictions, and plot those predictions:
            PlotUtil.plotTestData(model, testIter, allXYPoints, nPointsPerAxis);
        }
    }
}
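
        The classifier above relies on three pieces that DL4J handles internally: one-hot label encoding (done by RecordReaderDataSetIterator), the softmax activation of the output layer, and the negative log-likelihood loss. The following plain-Java sketch shows roughly what each of them computes for a single example. It is not DL4J code: the class SoftmaxNllSketch, its method names, the sample CSV row, and the logits are all made up for illustration, and the row layout (label first, then the two features) is an assumption based on labelIndex = 0 in the listing.

```java
import java.util.Arrays;

/** Plain-Java sketch (no DL4J): one-hot labels, softmax, and negative log-likelihood. */
final class SoftmaxNllSketch {

    /** One-hot encode a class index, e.g. class 1 of 2 becomes [0.0, 1.0]. */
    static double[] oneHot(int classIdx, int numClasses) {
        double[] v = new double[numClasses];
        v[classIdx] = 1.0;
        return v;
    }

    /** Numerically stable softmax: subtract the largest logit before exponentiating. */
    static double[] softmax(double[] logits) {
        double max = Arrays.stream(logits).max().getAsDouble();
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Negative log-likelihood of the true class under the predicted probabilities. */
    static double negativeLogLikelihood(double[] oneHotLabel, double[] probs) {
        double loss = 0.0;
        for (int i = 0; i < probs.length; i++) {
            if (oneHotLabel[i] != 0.0) {            // only the true class contributes
                loss -= oneHotLabel[i] * Math.log(probs[i]);
            }
        }
        return loss;
    }

    public static void main(String[] args) {
        // Hypothetical CSV row in the assumed layout: label, x, y
        String[] cols = "1,0.55,0.20".split(",");
        double[] label = oneHot(Integer.parseInt(cols[0]), 2);
        double[] features = {Double.parseDouble(cols[1]), Double.parseDouble(cols[2])};

        // Made-up logits for an untrained network that cannot yet tell the classes apart
        double[] probs = softmax(new double[]{0.0, 0.0});

        System.out.println("features = " + Arrays.toString(features));
        System.out.println("label    = " + Arrays.toString(label));
        System.out.println("probs    = " + Arrays.toString(probs));
        System.out.println("loss     = " + negativeLogLikelihood(label, probs)); // about 0.693 (ln 2)
    }
}
```

        With equal logits both classes get probability 0.5, so the loss is ln 2. As training pushes the probability of the true class toward 1, the loss falls toward 0, which is the trend the ScoreIterationListener output should show.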

        PlotUtil.java, the plotting tool

package com.algorithm.demo.dl4jexamples.utils;

import org.deeplearning4j.datasets.iterator.utilty.ListDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.AxisLocation;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.block.BlockBorder;
import org.jfree.chart.plot.DatasetRenderingOrder;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.GrayPaintScale;
import org.jfree.chart.renderer.PaintScale;
import org.jfree.chart.renderer.xy.XYBlockRenderer;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.chart.title.PaintScaleLegend;
import org.jfree.data.xy.*;
import org.jfree.ui.RectangleEdge;
import org.jfree.ui.RectangleInsets;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

import javax.swing.*;
import java.awt.*;
import java.util.ArrayList;
import java.util.List;

/**
 * Simple plotting methods for the MLPClassifier quickstart examples
 *
 * @author Alex Black
 */
public class PlotUtil {

    /**
     * Plot the training data. Assume 2d input, classification output
     *
     * @param model         Model to use to get predictions
     * @param trainIter     DataSet Iterator
     * @param backgroundIn  sets of x,y points in input space, plotted in the background
     * @param nDivisions    Number of points (per axis, for the backgroundIn/backgroundOut arrays)
     */
    public static void plotTrainingData(MultiLayerNetwork model, DataSetIterator trainIter, INDArray backgroundIn, int nDivisions) {
        double[] mins = backgroundIn.min(0).data().asDouble();
        double[] maxs = backgroundIn.max(0).data().asDouble();

        DataSet ds = allBatches(trainIter);
        INDArray backgroundOut = model.output(backgroundIn);

        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTrain(ds.getFeatures(), ds.getLabels())));

        JFrame f = new JFrame();
        f.add(panel);
        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        f.pack();
        f.setTitle("Training Data");

        f.setVisible(true);
        f.setLocation(0, 0);
    }

    /**
     * Plot the test data. Assume 2d input, classification output
     *
     * @param model         Model to use to get predictions
     * @param testIter      Test Iterator
     * @param backgroundIn  sets of x,y points in input space, plotted in the background
     * @param nDivisions    Number of points (per axis, for the backgroundIn/backgroundOut arrays)
     */
    public static void plotTestData(MultiLayerNetwork model, DataSetIterator testIter, INDArray backgroundIn, int nDivisions) {

        double[] mins = backgroundIn.min(0).data().asDouble();
        double[] maxs = backgroundIn.max(0).data().asDouble();

        INDArray backgroundOut = model.output(backgroundIn);
        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
        DataSet ds = allBatches(testIter);
        INDArray predicted = model.output(ds.getFeatures());
        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTest(ds.getFeatures(), ds.getLabels(), predicted)));

        JFrame f = new JFrame();
        f.add(panel);
        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        f.pack();
        f.setTitle("Test Data");

        f.setVisible(true);
        f.setLocationRelativeTo(null);
        //f.setLocation(100,100);

    }


    /**
     * Create data for the background data set
     */
    private static XYZDataset createBackgroundData(INDArray backgroundIn, INDArray backgroundOut) {
        int nRows = backgroundIn.rows();
        double[] xValues = new double[nRows];
        double[] yValues = new double[nRows];
        double[] zValues = new double[nRows];
        for (int i = 0; i < nRows; i++) {
            xValues[i] = backgroundIn.getDouble(i, 0);
            yValues[i] = backgroundIn.getDouble(i, 1);
            zValues[i] = backgroundOut.getDouble(i, 0);

        }

        DefaultXYZDataset dataset = new DefaultXYZDataset();
        dataset.addSeries("Series 1",
                new double[][]{xValues, yValues, zValues});
        return dataset;
    }

    //Training data
    private static XYDataset createDataSetTrain(INDArray features, INDArray labels) {
        int nRows = features.rows();

        int nClasses = 2; // Two classes; labels are one-hot encoded, so argMax recovers the class index.

        XYSeries[] series = new XYSeries[nClasses];
        for (int i = 0; i < series.length; i++) series[i] = new XYSeries("Class " + i);
        INDArray argMax = Nd4j.getExecutioner().exec(new ArgMax(new INDArray[]{labels},false,new int[]{1}))[0];
        for (int i = 0; i < nRows; i++) {
            int classIdx = (int) argMax.getDouble(i);
            series[classIdx].add(features.getDouble(i, 0), features.getDouble(i, 1));
        }

        XYSeriesCollection c = new XYSeriesCollection();
        for (XYSeries s : series) c.addSeries(s);
        return c;
    }

    //Test data
    private static XYDataset createDataSetTest(INDArray features, INDArray labels, INDArray predicted) {
        int nRows = features.rows();

        int nClasses = 2; // Two classes; labels and predictions each have one column per class.

        XYSeries[] series = new XYSeries[nClasses * nClasses];
        int[] series_index = new int[]{0, 3, 2, 1}; //little hack to make the charts look consistent.
        for (int i = 0; i < nClasses * nClasses; i++) {
            int trueClass = i / nClasses;
            int predClass = i % nClasses;
            String label = "actual=" + trueClass + ", pred=" + predClass;
            series[series_index[i]] = new XYSeries(label);
        }
        INDArray actualIdx = labels.argMax(1);
        INDArray predictedIdx = predicted.argMax(1);
        for (int i = 0; i < nRows; i++) {
            int classIdx = actualIdx.getInt(i);
            int predIdx = predictedIdx.getInt(i);
            int idx = series_index[classIdx * nClasses + predIdx];
            series[idx].add(features.getDouble(i, 0), features.getDouble(i, 1));
        }

        XYSeriesCollection c = new XYSeriesCollection();
        for (XYSeries s : series) c.addSeries(s);
        return c;
    }

    private static JFreeChart createChart(XYZDataset dataset, double[] mins, double[] maxs, int nPoints, XYDataset xyData) {
        NumberAxis xAxis = new NumberAxis("X");
        xAxis.setRange(mins[0], maxs[0]);


        NumberAxis yAxis = new NumberAxis("Y");
        yAxis.setRange(mins[1], maxs[1]);

        XYBlockRenderer renderer = new XYBlockRenderer();
        renderer.setBlockWidth((maxs[0] - mins[0]) / (nPoints - 1));
        renderer.setBlockHeight((maxs[1] - mins[1]) / (nPoints - 1));
        PaintScale scale = new GrayPaintScale(0, 1.0);
        renderer.setPaintScale(scale);
        XYPlot plot = new XYPlot(dataset, xAxis, yAxis, renderer);
        plot.setBackgroundPaint(Color.lightGray);
        plot.setDomainGridlinesVisible(false);
        plot.setRangeGridlinesVisible(false);
        plot.setAxisOffset(new RectangleInsets(5, 5, 5, 5));
        JFreeChart chart = new JFreeChart("", plot);
        chart.getXYPlot().getRenderer().setSeriesVisibleInLegend(0, false);


        NumberAxis scaleAxis = new NumberAxis("Probability (class 1)");
        scaleAxis.setAxisLinePaint(Color.white);
        scaleAxis.setTickMarkPaint(Color.white);
        scaleAxis.setTickLabelFont(new Font("Dialog", Font.PLAIN, 7));
        PaintScaleLegend legend = new PaintScaleLegend(new GrayPaintScale(),
                scaleAxis);
        legend.setStripOutlineVisible(false);
        legend.setSubdivisionCount(20);
        legend.setAxisLocation(AxisLocation.BOTTOM_OR_LEFT);
        legend.setAxisOffset(5.0);
        legend.setMargin(new RectangleInsets(5, 5, 5, 5));
        legend.setFrame(new BlockBorder(Color.red));
        legend.setPadding(new RectangleInsets(10, 10, 10, 10));
        legend.setStripWidth(10);
        legend.setPosition(RectangleEdge.LEFT);
        chart.addSubtitle(legend);

        ChartUtilities.applyCurrentTheme(chart);

        plot.setDataset(1, xyData);
        XYLineAndShapeRenderer renderer2 = new XYLineAndShapeRenderer();
        renderer2.setBaseLinesVisible(false);
        plot.setRenderer(1, renderer2);

        plot.setDatasetRenderingOrder(DatasetRenderingOrder.FORWARD);

        return chart;
    }

    public static INDArray generatePointsOnGraph(double xMin, double xMax, double yMin, double yMax, int nPointsPerAxis) {
        //generate all the x,y points
        double[][] evalPoints = new double[nPointsPerAxis * nPointsPerAxis][2];
        int count = 0;
        for (int i = 0; i < nPointsPerAxis; i++) {
            for (int j = 0; j < nPointsPerAxis; j++) {
                double x = i * (xMax - xMin) / (nPointsPerAxis - 1) + xMin;
                double y = j * (yMax - yMin) / (nPointsPerAxis - 1) + yMin;

                evalPoints[count][0] = x;
                evalPoints[count][1] = y;

                count++;
            }
        }

        return Nd4j.create(evalPoints);
    }

    /**
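     * Collects all the data and returns it as one minibatch. Only suitable for the small datasets used here.
     * @param iter iterator to drain; it is reset before and after reading
     * @return a single DataSet containing every example from the iterator
     */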
    private static DataSet allBatches(DataSetIterator iter) {

        List<DataSet> fullSet = new ArrayList<>();
        iter.reset();
        while (iter.hasNext()) {
            List<DataSet> miniBatchList = iter.next().asList();
            fullSet.addAll(miniBatchList);
        }
        iter.reset();
        return new ListDataSetIterator<>(fullSet,fullSet.size()).next();
    }

}

        DownloaderUtility.java, download utility class

package com.algorithm.demo.dl4jexamples.utils;

import org.apache.commons.io.FilenameUtils;
import org.nd4j.common.resources.Downloader;

import java.io.File;
import java.net.URL;

/**
 * Given a base url and a zipped file name downloads contents to a specified directory under ~/dl4j-examples-data
 * Will check md5 sum of downloaded file
 * <p>
 *
 * Sample Usage with an instantiation DATAEXAMPLE(baseurl,"DataExamples.zip","data-dir",md5,size):
 *
 * DATAEXAMPLE.Download() & DATAEXAMPLE.Download(true)
 * Will download DataExamples.zip from baseurl/DataExamples.zip to a temp directory,
 * Unzip it to ~/dl4j-examples-data/data-dir
 * Return the string "~/dl4j-examples-data/data-dir/DataExamples"
 *
 * DATAEXAMPLE.Download(false)
 * will perform the same download and unzip as above
 * But returns the string "~/dl4j-examples-data/data-dir" instead
 *
 *
 * @author susaneraly
 */
public enum DownloaderUtility {

    IRISDATA("IrisData.zip", "datavec-examples", "bb49e38bb91089634d7ef37ad8e430b8", "1KB"),
    ANIMALS("animals.zip", "dl4j-examples", "1976a1f2b61191d2906e4f615246d63e", "820KB"),
    ANOMALYSEQUENCEDATA("anomalysequencedata.zip", "dl4j-examples", "51bb7c50e265edec3a241a2d7cce0e73", "3MB"),
    CAPTCHAIMAGE("captchaImage.zip", "dl4j-examples", "1d159c9587fdbb1cbfd66f0d62380e61", "42MB"),
    CLASSIFICATIONDATA("classification.zip", "dl4j-examples", "dba31e5838fe15993579edbf1c60c355", "77KB"),
    DATAEXAMPLES("DataExamples.zip", "dl4j-examples", "e4de9c6f19aaae21fed45bfe2a730cbb", "2MB"),
    LOTTERYDATA("lottery.zip", "dl4j-examples", "1e54ac1210e39c948aa55417efee193a", "2MB"),
    NEWSDATA("NewsData.zip", "dl4j-examples", "0d08e902faabe6b8bfe5ecdd78af9f64", "21MB"),
    NLPDATA("nlp.zip", "dl4j-examples", "1ac7cd7ca08f13402f0e3b83e20c0512", "91MB"),
    PREDICTGENDERDATA("PredictGender.zip", "dl4j-examples", "42a3fec42afa798217e0b8687667257e", "3MB"),
    STYLETRANSFER("styletransfer.zip", "dl4j-examples", "b2b90834d667679d7ee3dfb1f40abe94", "3MB"),
    VIDEOEXAMPLE("video.zip","dl4j-examples", "56274eb6329a848dce3e20631abc6752", "8.5MB");

    private final String BASE_URL;
    private final String DATA_FOLDER;
    private final String ZIP_FILE;
    private final String MD5;
    private final String DATA_SIZE;
    private static final String AZURE_BLOB_URL = "https://dl4jdata.blob.core.windows.net/dl4j-examples";

    /**
     * For use with resources uploaded to Azure blob storage.
     *
     * @param zipFile    Name of zipfile. Should be a zip of a single directory with the same name
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String zipFile, String dataFolder, String md5, String dataSize) {
        this(AZURE_BLOB_URL + "/" + dataFolder, zipFile, dataFolder, md5, dataSize);
    }

    /**
     * Downloads a zip file from a base url to a specified directory under the user's home directory
     *
     * @param baseURL    URL of file
     * @param zipFile    Name of zipfile to download from baseURL i.e baseURL+"/"+zipFile gives full URL
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String baseURL, String zipFile, String dataFolder, String md5, String dataSize) {
        BASE_URL = baseURL;
        DATA_FOLDER = dataFolder;
        ZIP_FILE = zipFile;
        MD5 = md5;
        DATA_SIZE = dataSize;
    }

    public String Download() throws Exception {
        return Download(true);
    }

    public String Download(boolean returnSubFolder) throws Exception {
        String dataURL = BASE_URL + "/" + ZIP_FILE;
        String downloadPath = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), ZIP_FILE);
        String extractDir = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/" + DATA_FOLDER);
        if (!new File(extractDir).exists())
            new File(extractDir).mkdirs();
        String dataPathLocal = extractDir;
        if (returnSubFolder) {
            String resourceName = ZIP_FILE.substring(0, ZIP_FILE.lastIndexOf(".zip"));
            dataPathLocal = FilenameUtils.concat(extractDir, resourceName);
        }
        int downloadRetries = 10;
        if (!new File(dataPathLocal).exists() || new File(dataPathLocal).list().length == 0) {
            System.out.println("_______________________________________________________________________");
            System.out.println("Downloading data (" + DATA_SIZE + ") and extracting to \n\t" + dataPathLocal);
            System.out.println("_______________________________________________________________________");
            Downloader.downloadAndExtract("files",
                    new URL(dataURL),
                    new File(downloadPath),
                    new File(extractDir),
                    MD5,
                    downloadRetries);
        } else {
            System.out.println("_______________________________________________________________________");
            System.out.println("Example data present in \n\t" + dataPathLocal);
            System.out.println("_______________________________________________________________________");
        }
        return dataPathLocal;
    }
}
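
        The utility passes an MD5 string to Downloader.downloadAndExtract so the downloaded zip can be verified. How ND4J performs that check internally is not shown in this article, so the following is only a standard-library sketch of the general technique: hash the bytes with MessageDigest and compare the hex string with the expected value. The class Md5CheckSketch and its method names are hypothetical.

```java
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/** Standard-library sketch of an MD5 integrity check. Not the ND4J implementation. */
final class Md5CheckSketch {

    /** Returns the MD5 digest of the data as a 32-character lowercase hex string. */
    static String md5Hex(byte[] data) {
        try {
            byte[] digest = MessageDigest.getInstance("MD5").digest(data);
            // %032x left-pads with zeros so leading zero bytes are not dropped
            return String.format("%032x", new BigInteger(1, digest));
        } catch (NoSuchAlgorithmException e) {
            // MD5 is required to be available on every Java platform
            throw new IllegalStateException("MD5 not available", e);
        }
    }

    /** True if the data hashes to the expected MD5 (case-insensitive comparison). */
    static boolean matches(byte[] data, String expectedMd5) {
        return md5Hex(data).equalsIgnoreCase(expectedMd5);
    }

    public static void main(String[] args) {
        // For a real download, the bytes would come from java.nio.file.Files.readAllBytes(path).
        byte[] data = "hello".getBytes(StandardCharsets.UTF_8);
        System.out.println(md5Hex(data));
        System.out.println(matches(data, "5d41402abc4b2a76b9719d911017c592")); // true
    }
}
```

        MD5 is adequate for catching truncated or corrupted downloads, which is the purpose here, but it is not a safeguard against deliberate tampering.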

2. Running results

        For anyone already familiar with Java, the framework is quite comfortable to work with.

Origin blog.csdn.net/bashendixie5/article/details/123600031