文章目录
DJL库
DJL(Deep Java Library)是一个开源的深度学习工具包,它通过 Java API 简化深度学习模型的训练、测试、部署以及推理,开源许可协议为 Apache-2.0。
对于 Java 开发者而言,可以直接在 Java 中开发和使用原生的机器学习与深度学习模型,降低了深度学习开发的门槛。
通过DJL提供的直观的、高级的API,Java开发人员可以训练自己的模型,或者利用数据科学家用Python预先训练好的模型来进行推理。
Spring Boot 微服务集成 DJL
这里选择 Kotlin + Gradle + Spring Boot 搭建项目,并引入 spring-boot-starter-web
依赖
引入 djl-spring-boot-starter 依赖
// Gradle (Kotlin DSL) dependencies for the Spring Boot + DJL image-classification service.
dependencies {
// Spring MVC web starter (provides @RestController support used by PredictController).
implementation("org.springframework.boot:spring-boot-starter-web")
// DJL Spring Boot starter for the PyTorch engine. NOTE(review): the "-auto" variant presumably
// resolves the native PyTorch libraries for the host platform at runtime — confirm in DJL docs.
implementation("ai.djl.spring:djl-spring-boot-starter-pytorch-auto:0.15")
// DJL Spring Boot auto-configuration, pinned to the same version (0.15) as the engine starter.
implementation("ai.djl.spring:djl-spring-boot-starter-autoconfigure:0.15")
// JNA. NOTE(review): not referenced by the code shown here; presumably required by DJL's native bindings — confirm.
implementation("net.java.dev.jna:jna:5.11.0")
// SLF4J logging API.
implementation("org.slf4j:slf4j-api:1.7.36")
// Kotlin runtime reflection (required by Spring's Kotlin support) and the JDK 8 stdlib extensions.
implementation("org.jetbrains.kotlin:kotlin-reflect")
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
// Spring Boot test support, on the test classpath only.
testImplementation("org.springframework.boot:spring-boot-starter-test")
}
下载 resnet18 模型
// One-off setup: fetch the TorchScript ResNet-18 weights and the class-label list into the
// directory PredictUtil later loads the model from. The weights URL is gzipped but is saved as a
// plain .pt file (DownloadUtils presumably decompresses based on the .gz suffix — confirm).
listOf(
    "https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz"
        to "src/main/resources/models/resnet18/resnet18.pt",
    "https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt"
        to "src/main/resources/models/resnet18/synset.txt"
).forEach { (sourceUrl, targetPath) ->
    // A fresh progress bar per download, as before.
    DownloadUtils.download(sourceUrl, targetPath, ProgressBar())
}
模型文件会下载到 src/main/resources/models/resnet18 路径下。
定义预测工具类
// util/PredictUtil.kt
/**
 * Spring component that classifies images with the ResNet-18 model stored under [MODEL_DIR].
 *
 * The model is loaded once, on the first call to [predict], and shared by all later calls.
 * Each call creates and closes its own [Predictor]. Spring invokes [close] when the
 * application context shuts down, which releases the loaded model.
 */
@Component
class PredictUtil : AutoCloseable {

    // The Lazy holder is kept separately so close() can tell whether the model was ever loaded.
    private val lazyModel: Lazy<ZooModel<Image, Classifications>> = lazy { loadModel() }
    private val model by lazyModel

    /**
     * Builds the preprocessing pipeline: resize to 224x224, convert to a tensor, then normalize
     * each channel with the given mean/std values (the standard ImageNet ones).
     */
    private fun createPipeline(): Pipeline =
        Pipeline()
            .add(Resize(224, 224))
            .add(ToTensor())
            .add(
                Normalize(
                    floatArrayOf(0.485f, 0.456f, 0.406f),
                    floatArrayOf(0.229f, 0.224f, 0.225f)
                )
            )

    /** Wraps [pipeline] in an image-classification translator that applies softmax to the output. */
    private fun createTranslator(pipeline: Pipeline): Translator<Image, Classifications> =
        ImageClassificationTranslator.builder()
            .setPipeline(pipeline)
            .optApplySoftmax(true)
            .build()

    /** Loads the model from [MODEL_DIR]. This is expensive, so it is called once via [lazyModel]. */
    private fun loadModel(): ZooModel<Image, Classifications> {
        val criteria: Criteria<Image, Classifications> = Criteria.builder()
            .setTypes(Image::class.java, Classifications::class.java)
            .optModelPath(Paths.get(MODEL_DIR))
            .optTranslator(createTranslator(createPipeline()))
            .optProgress(ProgressBar())
            .build()
        return criteria.loadModel()
    }

    /**
     * Classifies the image stored in [file].
     *
     * @param file an image file readable by DJL's [ImageFactory]
     * @return the string form of the resulting [Classifications]
     */
    fun predict(file: File): String {
        // Close the input stream once the image has been decoded.
        val image = FileInputStream(file).use { ImageFactory.getInstance().fromInputStream(it) }
        // Predictors are created per call from the shared model and closed right after use.
        val classifications = model.newPredictor().use { it.predict(image) }
        println(classifications)
        return classifications.toString()
    }

    /** Releases the model if it was loaded. Spring calls this on AutoCloseable beans at shutdown. */
    override fun close() {
        if (lazyModel.isInitialized()) model.close()
    }

    private companion object {
        // NOTE: resolved against the process working directory, not the classpath. This only works
        // when the application is started from the project root, not from a packaged jar.
        const val MODEL_DIR = "src/main/resources/models/resnet18"
    }
}
写预测接口
// controller/PredictController.kt
/** REST endpoint that runs image classification on an uploaded file. */
@RestController
@RequestMapping(path = ["/predict"])
class PredictController {

    // Injected by Spring via @Resource. lateinit avoids constructing a throwaway, unmanaged
    // PredictUtil that Spring would then overwrite reflectively in a final field.
    @Resource
    private lateinit var predictUtil: PredictUtil

    /**
     * Classifies an uploaded image.
     *
     * Expects a `multipart/form-data` request whose file part is named `file`. Spring binds a
     * [MultipartFile] parameter by its name, so no @RequestBody/@RequestParam annotation is needed.
     *
     * @param file the uploaded image
     * @return the classification result text produced by [PredictUtil.predict]
     */
    @RequestMapping(method = [RequestMethod.POST], path = ["/"])
    fun predict(file: MultipartFile): String {
        // PredictUtil works on a File, so spool the upload to a temporary file first.
        val tmp = File.createTempFile("temp", null)
        return try {
            file.transferTo(tmp)
            predictUtil.predict(tmp)
        } finally {
            // Remove the temp file so each request does not leave one behind on disk.
            tmp.delete()
        }
    }
}
使用 Postman 测试
向 http://127.0.0.1:8080/predict/ 发送 POST 请求(以 form-data 形式上传图片,字段名为 file),即可成功运行并返回分类结果。