机器学习笔记 - pytorch + unet + 数据科学碗竞赛 医学图像分割

一、数据集概述

1、数据科学碗竞赛

        数据集来自Kaggle网站的2018数据科学碗竞赛。 数据科学碗竞赛由 Booz Allen 和 Kaggle 主办的 Data Science Bowl 是全球首屈一指的社会公益竞赛数据科学。

        数据科学碗汇集了数据科学家、技术人员、领域专家和组织,以应对世界数据和技术的挑战。这是一个平台,人们可以通过它来驾驭他们的热情,释放他们的好奇心,并扩大他们的影响力,从而在全球范围内实现变革。

        为了展示比赛,Booz Allen 与Kaggle合作,Kaggle是领先的在线数据科学竞赛社区,在全球拥有超过 100 万会员。在 90 天的时间里,参与者,无论是单独还是团队合作,都可以访问独特的数据集,以开发应对特定挑战的算法。每年,比赛都会向顶级团队颁发现金奖励。

        2015 年,参与者检查了 100,000 多张由哈特菲尔德海洋科学中心提供的水下图像,以巨大的速度和规模评估海洋健康。超过 1,000 个团队参与,提交了 17,000 多个解决方案来应对挑战。获胜团队深海团队开发了一种分类算法,该算法比当前最先进的算法高出 10% 以上,在某些类别中实现了人类水平的表现。

        2016 年,他们将分析应用于心脏病学,改变了评估心脏功能的实践。尽管挑战显然比前一年更加复杂,但本次比赛收到了来自 1,100 多个团队的近 9,300 份参赛作品。事实上,获胜团队 Tencia Lee 和 Qi Liu是对冲基金交易员,而不是传统的数据科学家。美国国立卫生研究院正在进一步研究结果,并与医学和研究界分享成功的方法。

        2017年,近10,000名参与者致力于改进肺癌筛查技术,提交超过18,000个算法。初步结果表明,误报率降低了 10%,同时准确度比现有技术提高了 10%。由 Bonnie J. Addario 肺癌基金会和 DrivenData.org 赞助的后续竞赛正在进行中,以将 2017 年数据科学碗算法的进步从概念推广到临床。 

        2018年比赛网址,可以下载数据集。

2018 Data Science Bowl | KaggleFind the nuclei in divergent images to advance medical discoveryhttps://www.kaggle.com/competitions/data-science-bowl-2018/overview

2、数据集查看

        数据下载完成之后解压,会看到如下文件,这里只暂时关注stage1_train.zip和stage1_test.zip。

         其中一张样图。

样图

         样图对应的部分mask截图。

标记好的mask

二、参考代码

        参考代码下载

https://github.com/4uiiurz1/pytorch-nested-uneticon-default.png?t=M276https://github.com/4uiiurz1/pytorch-nested-unet

1、代码结构

        从py文件名称基本可以看出来文件的用处。

标题

 2、下载数据集到input/并解压缩

 3、图像预处理

        执行脚本进行图像预处理,主要是进行了mask的合并处理等。

python preprocess_dsb2018.py

         处理完成会得到如下图像和mask。

标题

 4、进行训练

python train.py --dataset dsb2018_96 --arch NestedUNet

        训练完成之后会再models文件夹下生成训练好的模型。 

 5、进行验证

        执行以下脚本进行验证,会再outputs文件夹下输出结果。

python val.py --name dsb2018_96_NestedUNet_woDS

        因为训练epochs次数不够,所以还有些模糊。

三、转ONNX

        调用pth_2onnx.py,进行onnx模型的转换。

def pth_2onnx():
    """
    pytorch 模型转换为onnx模型
    :return:
    """
    torch_model = torch.load('./models/dsb2018_96_NestedUNet_woDS/model.pth')

    config = vars(parse_args())
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])
    model.load_state_dict(torch_model)
    batch_size = 1  # 批处理大小
    input_shape = (3, 96, 96)  # 输入数据

    # set the model to inference mode
    model.eval()
    print(model)
    x = torch.randn(batch_size, *input_shape)  # 生成张量
    export_onnx_file = "model.onnx"  # 目的ONNX文件名
    torch.onnx.export(model,
                      x,
                      export_onnx_file,
                      # 注意这个地方版本选择为11
                      opset_version=11)

        调用onnx模型进行推理

ort_session = ort.InferenceSession('model.onnx')
input_name = ort_session.get_inputs()[0].name

img = cv2.imread('ba0c9e776404370429e80.png')  # 02_test.tif')#demo.png
#img = cv2.resize(img, (96, 96))

nor = alb.Normalize()
img = nor.apply(image=img)
img = img.astype('float32') / 255
#img = img.transpose(2, 1, 0)
img = cv2.resize(img, (96, 96))

tensor = transforms.ToTensor()(img)
tensor = tensor.unsqueeze_(0)

ort_outs = ort_session.run(None, {input_name: tensor.cpu().numpy()})

img_out = ort_outs[0]
img_out = torch.from_numpy(img_out)
img_out = torch.sigmoid(img_out).cpu().numpy()

cv2.imwrite(os.path.join('result.png'), (img_out[0][0] * 255).astype('uint8'))

猜你喜欢

转载自blog.csdn.net/bashendixie5/article/details/123282703
今日推荐