最新的 Java 深度学习库 DJL（附 NLP 问答示例）

一、简介

DJL（Deep Java Library）是一个开源库，用于以 Java 构建和部署深度学习模型，做到一次编写、随处运行。你可以用 DJL 开发模型，并在自己选择的深度学习引擎上运行。其直观的 API 采用原生 Java 概念，屏蔽了深度学习中的复杂细节。你既可以引入自己的模型，也可以使用 DJL 模型库（Model Zoo）中的现成模型，在几分钟内完成部署。

二、开源地址：
https://github.com/awslabs/djl（该仓库现已迁移至 https://github.com/deepjavalibrary/djl）

三、示例与用法

1、Single-shot object detection example
2、Train your first model
3、Image classification example
4、Transfer learning example
5、Train SSD model example
6、Multi-threaded inference example
7、Bert question and answer example
8、Instance segmentation example
9、Pose estimation example
10、Action recognition example
11、Multi-label dataset training example

四、官网地址：

https://djl.ai/

五、代码如下：
1、依赖（pom.xml）：

<?xml version="1.0" encoding="UTF-8"?>
<!--
  Maven build for the DJL examples in this article: a Spring Boot 2.2.5 application that pulls in
  the DJL 0.3.0 API, datasets, model zoos and example utilities.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <!-- Inherit dependency versions and plugin defaults from Spring Boot 2.2.5. -->
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.2.5.RELEASE</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>com.citydo</groupId>
    <artifactId>bigdata</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>bigdata</name>
    <description>Demo project for Spring Boot</description>

    <properties>
        <!-- Compile for Java 8. -->
        <java.version>1.8</java.version>
        <!-- Spring Cloud release train; its BOM is imported in dependencyManagement below. -->
        <spring-cloud.version>Hoxton.SR1</spring-cloud.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <!-- NOTE(review): neither example class in this article imports Spring Cloud AWS;
             confirm something else in the project needs it before keeping this dependency. -->
        <dependency>
            <groupId>org.springframework.cloud</groupId>
            <artifactId>spring-cloud-starter-aws</artifactId>
        </dependency>
        <!-- All DJL artifacts below are pinned to 0.3.0; presumably they must stay on the
             same version as each other. -->
        <!-- examples: presumably provides ai.djl.examples.training.util
             (Arguments, TrainingUtils, ExampleTrainingResult) used by TrainCaptcha. -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>examples</artifactId>
            <version>0.3.0</version>
        </dependency>
        <!-- api: core DJL API (Model, Predictor, NDArray, Trainer, ...). -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.3.0</version>
        </dependency>

        <!-- Optional Lombok; the example classes shown in this article do not use it. -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>

        <!-- basicdataset: presumably provides ai.djl.basicdataset.CaptchaDataset. -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>0.3.0</version>
        </dependency>
        <!-- model-zoo: presumably provides ai.djl.basicmodelzoo (ResNetV1). -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>0.3.0</version>
        </dependency>
        <!-- mxnet-model-zoo: presumably provides ai.djl.mxnet.zoo.nlp.qa.QAInput and the
             pretrained BERT QA model. NOTE(review): no MXNet native-library artifact is declared
             directly; confirm it arrives transitively or add the matching mxnet-native artifact. -->
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-model-zoo</artifactId>
            <version>0.3.0</version>
        </dependency>
        <!-- Test support; the JUnit Vintage engine (JUnit 3/4 runner) is excluded. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
            <exclusions>
                <exclusion>
                    <groupId>org.junit.vintage</groupId>
                    <artifactId>junit-vintage-engine</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
    </dependencies>

    <!-- Import the Spring Cloud BOM so spring-cloud-* dependencies above need no explicit version. -->
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.cloud</groupId>
                <artifactId>spring-cloud-dependencies</artifactId>
                <version>${spring-cloud.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <build>
        <plugins>
            <!-- Spring Boot packaging/run support. -->
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

</project>

2、问答示例（BERT QA 推理）：

package com.citydo.bigdata.nlputils;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.mxnet.zoo.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An example of inference using BertQA.
 *
 * <p>Loads a question-answering model from the model zoo (filtered to the {@code bert} backbone
 * and the {@code book_corpus_wiki_en_uncased} dataset) and asks it a question about a paragraph.
 * {@link #predict()} runs the built-in BBC Japan sample; {@link #predict(String, String)} accepts
 * any question/paragraph pair.
 *
 * <p>See:
 *
 * <ul>
 *   <li>the <a href="https://github.com/awslabs/djl/blob/master/jupyter/BERTQA.ipynb">jupyter
 *       demo</a> with more information about BERT.
 *   <li>the <a
 *       href="https://github.com/awslabs/djl/blob/master/examples/docs/BERT_question_and_answer.md">docs</a>
 *       for information about running this example.
 * </ul>
 */
public final class BertQaInference {

    private static final Logger logger = LoggerFactory.getLogger(BertQaInference.class);

    /** Question asked by {@link #predict()}. */
    private static final String DEFAULT_QUESTION = "When did BBC Japan start broadcasting?";

    /** Context paragraph searched by {@link #predict()}. */
    private static final String DEFAULT_PARAGRAPH =
            "BBC Japan was a general entertainment Channel.\n"
                    + "Which operated between December 2004 and April 2006.\n"
                    + "It ceased operations after its Japanese distributor folded.";

    /**
     * Third argument passed to {@link QAInput}; presumably the maximum token sequence length fed
     * to the model. TODO confirm against the QAInput docs for the DJL version in use.
     */
    private static final int SEQ_LENGTH = 384;

    private BertQaInference() {}

    /**
     * Runs the built-in sample question and logs the answer.
     *
     * @param args ignored
     * @throws IOException if loading the model fails with an I/O error
     * @throws TranslateException if the predictor fails to process the input or output
     * @throws ModelException if no model matching the criteria can be loaded
     */
    public static void main(String[] args) throws IOException, TranslateException, ModelException {
        String answer = BertQaInference.predict();
        logger.info("Answer: {}", answer);
    }

    /**
     * Answers the built-in sample question about BBC Japan.
     *
     * @return the answer text produced by the model
     * @throws IOException if loading the model fails with an I/O error
     * @throws TranslateException if the predictor fails to process the input or output
     * @throws ModelException if no model matching the criteria can be loaded
     */
    public static String predict() throws IOException, TranslateException, ModelException {
        return predict(DEFAULT_QUESTION, DEFAULT_PARAGRAPH);
    }

    /**
     * Answers {@code question} using {@code paragraph} as the context.
     *
     * <p>The model is loaded from the zoo and closed again on every call; callers answering many
     * questions may prefer to keep a {@link ZooModel} open and reuse its predictor.
     *
     * @param question the question to ask; must not be {@code null}
     * @param paragraph the text in which the answer is searched; must not be {@code null}
     * @return the answer text produced by the model
     * @throws NullPointerException if {@code question} or {@code paragraph} is {@code null}
     * @throws IOException if loading the model fails with an I/O error
     * @throws TranslateException if the predictor fails to process the input or output
     * @throws ModelException if no model matching the criteria can be loaded
     */
    public static String predict(String question, String paragraph)
            throws IOException, TranslateException, ModelException {
        Objects.requireNonNull(question, "question");
        Objects.requireNonNull(paragraph, "paragraph");

        QAInput input = new QAInput(question, paragraph, SEQ_LENGTH);
        logger.info("Paragraph: {}", input.getParagraph());
        logger.info("Question: {}", input.getQuestion());

        Criteria<QAInput, String> criteria =
                Criteria.builder()
                        .optApplication(Application.NLP.QUESTION_ANSWER)
                        .setTypes(QAInput.class, String.class)
                        .optFilter("backbone", "bert")
                        .optFilter("dataset", "book_corpus_wiki_en_uncased")
                        .optProgress(new ProgressBar())
                        .build();

        // Resources close in reverse order: the predictor first, then the model.
        try (ZooModel<QAInput, String> model = ModelZoo.loadModel(criteria);
                Predictor<QAInput, String> predictor = model.newPredictor()) {
            return predictor.predict(input);
        }
    }
}


3、训练模型（验证码识别）：

package com.citydo.bigdata.nlputils;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicdataset.CaptchaDataset;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.ExampleTrainingResult;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.Dataset.Usage;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.SimpleCompositeLoss;
import ai.djl.training.loss.SoftmaxCrossEntropyLoss;
import ai.djl.training.util.ProgressBar;
import java.io.IOException;
import java.nio.file.Paths;

/**
 * An example of training a CAPTCHA solving model.
 *
 * <p>A 50-layer {@link ResNetV1} emits {@code CAPTCHA_OPTIONS * CAPTCHA_LENGTH} scores. These are
 * reshaped and split into one output per character position. Each position gets its own softmax
 * cross-entropy loss and its own accuracy evaluator.
 *
 * <p>See this <a
 * href="https://github.com/awslabs/djl/blob/master/examples/docs/train_captcha.md">doc</a> for
 * information about this example.
 */
public final class TrainCaptcha {

    /** Name under which the model is trained (passed to {@code fit}) and saved. */
    private static final String MODEL_NAME = "captcha";

    private TrainCaptcha() {}

    /**
     * Command-line entry point: trains and saves the model, discarding the returned result.
     *
     * @param args arguments understood by {@link Arguments#parseArgs(String[])}
     * @throws Exception if any step of training or saving fails
     */
    public static void main(String[] args) throws Exception {
        runExample(args);
    }

    /**
     * Trains the CAPTCHA model, then saves it to the configured output directory.
     *
     * @param args arguments understood by {@link Arguments#parseArgs(String[])}
     * @return the result captured from the trainer after fitting finished
     * @throws Exception if any step of training or saving fails
     */
    public static ExampleTrainingResult runExample(String[] args) throws Exception {
        Arguments arguments = Arguments.parseArgs(args);

        try (Model model = Model.newInstance()) {
            model.setBlock(getBlock());

            RandomAccessDataset trainingSet = getDataset(Usage.TRAIN, arguments);
            RandomAccessDataset validateSet = getDataset(Usage.VALIDATION, arguments);
            DefaultTrainingConfig config = setupTrainingConfig(arguments);

            ExampleTrainingResult outcome =
                    fitModel(model, config, trainingSet, validateSet, arguments);

            // The model is saved only after fitModel's trainer has been closed.
            model.save(Paths.get(arguments.getOutputDir()), MODEL_NAME);
            return outcome;
        }
    }

    /**
     * Fits {@code model} on the given datasets inside a dedicated trainer.
     *
     * <p>The result is captured before the trainer is closed.
     */
    private static ExampleTrainingResult fitModel(
            Model model,
            DefaultTrainingConfig config,
            RandomAccessDataset trainingSet,
            RandomAccessDataset validateSet,
            Arguments arguments)
            throws Exception {
        try (Trainer trainer = model.newTrainer(config)) {
            trainer.setMetrics(new Metrics());

            // One single-channel IMAGE_HEIGHT x IMAGE_WIDTH sample. This is presumably
            // batch-first, matching the (1, height, width) image shape given to the ResNet.
            // TODO confirm.
            trainer.initialize(
                    new Shape(1, 1, CaptchaDataset.IMAGE_HEIGHT, CaptchaDataset.IMAGE_WIDTH));

            TrainingUtils.fit(
                    trainer,
                    arguments.getEpoch(),
                    trainingSet,
                    validateSet,
                    arguments.getOutputDir(),
                    MODEL_NAME);

            return new ExampleTrainingResult(trainer);
        }
    }

    /**
     * Builds a training config with one loss and one accuracy evaluator per character position.
     *
     * @param arguments parsed command-line arguments
     * @return the training configuration
     */
    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        SimpleCompositeLoss compositeLoss = new SimpleCompositeLoss();
        for (int position = 0; position < CaptchaDataset.CAPTCHA_LENGTH; position++) {
            // The second argument presumably selects which network output this loss reads.
            compositeLoss.addLoss(new SoftmaxCrossEntropyLoss("loss_digit_" + position), position);
        }

        // NOTE(review): the values passed to Defaults.logging (model dir, batch size, epoch
        // count, max GPUs, output dir) look mismatched with what a logging listener usually
        // needs. Confirm each position against the TrainingListener.Defaults.logging signature
        // of the DJL version in use.
        DefaultTrainingConfig trainingConfig =
                new DefaultTrainingConfig(compositeLoss)
                        .optDevices(Device.getDevices(arguments.getMaxGpus()))
                        .addTrainingListeners(
                                TrainingListener.Defaults.logging(
                                        arguments.getModelDir(),
                                        arguments.getBatchSize(),
                                        arguments.getEpoch(),
                                        arguments.getMaxGpus(),
                                        arguments.getOutputDir()));

        for (int position = 0; position < CaptchaDataset.CAPTCHA_LENGTH; position++) {
            trainingConfig.addEvaluator(new Accuracy("acc_digit_" + position, position));
        }
        return trainingConfig;
    }

    /**
     * Builds the CAPTCHA dataset split for {@code usage} and prepares it with a progress bar.
     *
     * @param usage which split to load
     * @param arguments supplies the batch size and the iteration cap
     * @return the prepared dataset
     * @throws IOException if preparing the dataset fails
     */
    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        CaptchaDataset captcha =
                CaptchaDataset.builder()
                        .optUsage(usage)
                        // The boolean flag presumably enables random sampling. TODO confirm.
                        .setSampling(arguments.getBatchSize(), true)
                        .optMaxIteration(arguments.getMaxIterations())
                        .build();
        captcha.prepare(new ProgressBar());
        return captcha;
    }

    /**
     * Builds the network: a 50-layer ResNet followed by a per-character split of its output.
     *
     * @return the complete block to train
     */
    private static Block getBlock() {
        Shape imageShape = new Shape(1, CaptchaDataset.IMAGE_HEIGHT, CaptchaDataset.IMAGE_WIDTH);
        Block backbone =
                ResNetV1.builder()
                        .setNumLayers(50)
                        .setImageShape(imageShape)
                        .setOutSize(CaptchaDataset.CAPTCHA_OPTIONS * CaptchaDataset.CAPTCHA_LENGTH)
                        .build();

        return new SequentialBlock().add(backbone).add(TrainCaptcha::splitPerCharacter);
    }

    /**
     * Turns the flat ResNet output into one score array per CAPTCHA character position.
     *
     * @param resnetOutputList the ResNet output; must hold exactly one array
     * @return {@code CAPTCHA_LENGTH} arrays, one per character position
     */
    private static NDList splitPerCharacter(NDList resnetOutputList) {
        NDArray flat = resnetOutputList.singletonOrThrow();
        NDList pieces =
                flat.reshape(-1, CaptchaDataset.CAPTCHA_LENGTH, CaptchaDataset.CAPTCHA_OPTIONS)
                        .split(CaptchaDataset.CAPTCHA_LENGTH, 1);

        NDList perCharacter = new NDList(CaptchaDataset.CAPTCHA_LENGTH);
        for (NDArray piece : pieces) {
            // Drop axis 1, which is presumably size 1 after splitting it into CAPTCHA_LENGTH parts.
            perCharacter.add(piece.squeeze(1));
        }
        return perCharacter;
    }
}

（本文转载自 https://blog.csdn.net/qq_32447301/article/details/104560771）