[PyTorch] Image Data Preprocessing

This post records some syntax and functions commonly used for data preprocessing in deep learning.

Image transforms with torchvision.transforms

[PyTorch Study Notes] 2.3 Twenty-two transforms image preprocessing methods - Zhihu

Image preprocessing with TORCHVISION.TRANSFORMS (阿巫兮兮's blog, CSDN)

PyTorch 09: transforms image transforms, operations, and custom methods - YEY Blog
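
As a quick reference alongside those posts, a typical torchvision.transforms pipeline for 2D images looks like this (the parameter values are illustrative, not taken from the links):

from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),        # PIL image -> float tensor in [0, 1], CHW layout
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet statistics
                         std=[0.229, 0.224, 0.225]),
])
# tensor = preprocess(pil_image)  # pil_image: a PIL.Image loaded elsewhere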

2D and 3D random cropping:

import random
import torch  # needed by the RandomCrop_3d class below


def random_crop_2d(img, label, crop_size):
    """Randomly crop a patch of size crop_size from img and label at the same location."""
    random_x_max = img.shape[0] - crop_size[0]
    random_y_max = img.shape[1] - crop_size[1]

    if random_x_max < 0 or random_y_max < 0:
        return None  # the requested crop is larger than the image

    x_random = random.randint(0, random_x_max)
    y_random = random.randint(0, random_y_max)

    crop_img = img[x_random:x_random + crop_size[0], y_random:y_random + crop_size[1]]
    crop_label = label[x_random:x_random + crop_size[0], y_random:y_random + crop_size[1]]

    return crop_img, crop_label


def random_crop_3d(img, label, crop_size):
    """Same as random_crop_2d, but cropping a 3D sub-volume."""
    random_x_max = img.shape[0] - crop_size[0]
    random_y_max = img.shape[1] - crop_size[1]
    random_z_max = img.shape[2] - crop_size[2]

    if random_x_max < 0 or random_y_max < 0 or random_z_max < 0:
        return None

    x_random = random.randint(0, random_x_max)
    y_random = random.randint(0, random_y_max)
    z_random = random.randint(0, random_z_max)

    crop_img = img[x_random:x_random + crop_size[0], y_random:y_random + crop_size[1], z_random:z_random + crop_size[2]]
    crop_label = label[x_random:x_random + crop_size[0], y_random:y_random + crop_size[1],
                 z_random:z_random + crop_size[2]]

    return crop_img, crop_label

class RandomCrop_3d:
    def __init__(self, slices):
        self.slices = slices

    def _get_range(self, slices, crop_slices):
        if slices < crop_slices:
            start = 0
        else:
            start = random.randint(0, slices - crop_slices)
        end = start + crop_slices
        if end > slices:
            end = slices
        return start, end

    def __call__(self, img, mask):
        # img / mask are (C, D, H, W); crop self.slices slices along the depth
        # dimension (dim 1). Volumes shorter than self.slices are zero-padded.
        ss, es = self._get_range(mask.size(1), self.slices)

        tmp_img = torch.zeros((img.size(0), self.slices, img.size(2),img.size(3)))
        tmp_mask = torch.zeros((mask.size(0), self.slices, mask.size(2),mask.size(3)))

        tmp_img[:, :es - ss] = img[:, ss:es]
        tmp_mask[:, :es - ss] = mask[:, ss:es]
        return tmp_img, tmp_mask
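
A quick sanity check of these croppers on dummy data (the shapes below are illustrative assumptions):

import numpy as np

img = np.random.rand(512, 512)
label = np.random.randint(0, 2, (512, 512))
result = random_crop_2d(img, label, crop_size=(256, 256))
if result is not None:                        # None means the crop exceeds the image
    crop_img, crop_label = result
    print(crop_img.shape)                     # (256, 256)

crop = RandomCrop_3d(slices=32)               # operates on (C, D, H, W) tensors
img_t = torch.rand(1, 48, 128, 128)
mask_t = torch.randint(0, 2, (1, 48, 128, 128)).float()
img_c, mask_c = crop(img_t, mask_t)
print(img_c.shape)                            # torch.Size([1, 32, 128, 128])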

Some transforms-style image-processing classes:

"""
This part is based on the dataset class implemented by pytorch, 
including train_dataset and test_dataset, as well as data augmentation
"""
from torch.utils.data import Dataset
import torch
import numpy as np
import random
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import normalize

#----------------------data augment-------------------------------------------
class Resize:
    def __init__(self, scale):
        # self.shape = [shape, shape, shape] if isinstance(shape, int) else shape
        self.scale = scale

    def __call__(self, img, mask):
        img, mask = img.unsqueeze(0), mask.unsqueeze(0).float()
        img = F.interpolate(img, scale_factor=(1,self.scale,self.scale),mode='trilinear', align_corners=False, recompute_scale_factor=True)
        mask = F.interpolate(mask, scale_factor=(1,self.scale,self.scale), mode="nearest", recompute_scale_factor=True)
        return img[0], mask[0]

class RandomResize:
    def __init__(self,s_rank, w_rank,h_rank):
        self.w_rank = w_rank
        self.h_rank = h_rank
        self.s_rank = s_rank

    def __call__(self, img, mask):
        random_w = random.randint(self.w_rank[0],self.w_rank[1])
        random_h = random.randint(self.h_rank[0],self.h_rank[1])
        random_s = random.randint(self.s_rank[0],self.s_rank[1])
        self.shape = [random_s,random_h,random_w]
        img, mask = img.unsqueeze(0), mask.unsqueeze(0).float()
        img = F.interpolate(img, size=self.shape,mode='trilinear', align_corners=False)
        mask = F.interpolate(mask, size=self.shape, mode="nearest")
        return img[0], mask[0].long()

class RandomCrop:
    def __init__(self, slices):
        self.slices =  slices

    def _get_range(self, slices, crop_slices):
        if slices < crop_slices:
            start = 0
        else:
            start = random.randint(0, slices - crop_slices)
        end = start + crop_slices
        if end > slices:
            end = slices
        return start, end

    def __call__(self, img, mask):

        ss, es = self._get_range(mask.size(1), self.slices)
        
        # print(self.shape, img.shape, mask.shape)
        tmp_img = torch.zeros((img.size(0), self.slices, img.size(2), img.size(3)))
        tmp_mask = torch.zeros((mask.size(0), self.slices, mask.size(2), mask.size(3)))
        tmp_img[:,:es-ss] = img[:,ss:es]
        tmp_mask[:,:es-ss] = mask[:,ss:es]
        return tmp_img, tmp_mask

class RandomFlip_LR:
    def __init__(self, prob=0.5):
        self.prob = prob

    def _flip(self, img, prob):
        if prob[0] <= self.prob:
            img = img.flip(2)
        return img

    def __call__(self, img, mask):
        prob = (random.uniform(0, 1), random.uniform(0, 1))
        return self._flip(img, prob), self._flip(mask, prob)

class RandomFlip_UD:
    def __init__(self, prob=0.5):
        self.prob = prob

    def _flip(self, img, prob):
        if prob[1] <= self.prob:
            img = img.flip(3)
        return img

    def __call__(self, img, mask):
        prob = (random.uniform(0, 1), random.uniform(0, 1))
        return self._flip(img, prob), self._flip(mask, prob)

class RandomRotate:
    def __init__(self, max_cnt=3):
        self.max_cnt = max_cnt

    def _rotate(self, img, cnt):
        img = torch.rot90(img,cnt,[1,2])
        return img

    def __call__(self, img, mask):
        cnt = random.randint(0,self.max_cnt)
        return self._rotate(img, cnt), self._rotate(mask, cnt)


class Center_Crop:
    def __init__(self, base, max_size):
        self.base = base  # base defaults to 16, since after 4 rounds of 2x downsampling a size of 16 becomes 1
        self.max_size = max_size
        if self.max_size % self.base:
            # max_size caps the number of sampled slices to avoid GPU memory overflow;
            # it must also be a multiple of base (16)
            self.max_size = self.max_size - self.max_size % self.base
    def __call__(self, img , label):
        if img.size(1) < self.base:
            return None
        slice_num = img.size(1) - img.size(1) % self.base
        slice_num = min(self.max_size, slice_num)

        left = img.size(1)//2 - slice_num//2
        right =  img.size(1)//2 + slice_num//2

        crop_img = img[:,left:right]
        crop_label = label[:,left:right]
        return crop_img, crop_label

class ToTensor:
    def __init__(self):
        self.to_tensor = transforms.ToTensor()

    def __call__(self, img, mask):
        img = self.to_tensor(img)
        mask = torch.from_numpy(np.array(mask))
        return img, mask[None]


class Normalize:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, img, mask):
        return normalize(img, self.mean, self.std, False), mask


class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask):
        for t in self.transforms:
            img, mask = t(img, mask)
        return img, mask
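
A minimal sketch of chaining these classes on a dummy (C, D, H, W) volume (the sizes are illustrative):

train_transforms = Compose([
    RandomCrop(slices=16),
    RandomFlip_LR(prob=0.5),
    RandomFlip_UD(prob=0.5),
])

img = torch.rand(1, 24, 64, 64)
mask = torch.randint(0, 2, (1, 24, 64, 64)).float()
img, mask = train_transforms(img, mask)
print(img.shape, mask.shape)  # both torch.Size([1, 16, 64, 64])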

Splitting the dataset into train + val (the split ratio is user-defined):

import os
import shutil
from random import sample

root = './data/head'
Origin = 'images'
Segmen = 'labels'
n = 0.8  # fraction of the data used for training

# Split the data
## collect the file names in the source folder
data_file = os.listdir(f'{root}/{Origin}')

train_size = int(len(data_file) * n)
train_img_url = sample(data_file, train_size)
val_img_url = list(set(data_file) - set(train_img_url))       ## set difference: the remaining files go to val

## Copy the images
os.makedirs(f'{root}/train/{Origin}', exist_ok=True)  ## create train/ and its source-image folder
os.makedirs(f'{root}/train/{Segmen}', exist_ok=True)  ## create the segmentation folder

for i in range(len(train_img_url)):          ## populate train
    ## copy the source image
    image = os.path.join(f'{root}/{Origin}', train_img_url[i]).replace('\\', '/')
    shutil.copy(image, f'{root}/train/{Origin}')

    ## copy the segmentation map (labels use .png, images use .jpg)
    seg = os.path.join(f'{root}/{Segmen}', train_img_url[i].replace('jpg', 'png')).replace('\\', '/')
    shutil.copy(seg, f'{root}/train/{Segmen}')


os.makedirs(f'{root}/Val/{Origin}', exist_ok=True)   ## create Val/ and its subfolders
os.makedirs(f'{root}/Val/{Segmen}', exist_ok=True)

for i in range(len(val_img_url)):          ## populate Val
    ## copy the source image
    image = os.path.join(f'{root}/{Origin}', val_img_url[i]).replace('\\', '/')
    shutil.copy(image, f'{root}/Val/{Origin}')

    ## copy the segmentation map
    seg = os.path.join(f'{root}/{Segmen}', val_img_url[i].replace('jpg', 'png')).replace('\\', '/')
    shutil.copy(seg, f'{root}/Val/{Segmen}')


with open(os.path.join(root, 'train_path_list.txt'), 'w') as f:
    for name in train_img_url:
        f.write(os.path.join(f'{root}/{Origin}', name) + "\n")

with open(os.path.join(root, 'val_path_list.txt'), 'w') as f:
    for name in val_img_url:
        f.write(os.path.join(f'{root}/{Origin}', name) + "\n")

3D reconstruction of CT volumes in NIfTI (.nii) format:

import vtk

def showNiiVtk3D(niipath):
    render = vtk.vtkRenderer()  # the "stage": instantiate the renderer
    renWin = vtk.vtkRenderWindow()  # instantiate the render window
    ir = vtk.vtkRenderWindowInteractor()  # platform-independent mouse/keyboard/timer event handling
    ir.SetRenderWindow(renWin)  # attach the interactor to renWin
    renWin.AddRenderer(render)  # add the renderer to the window
    style = vtk.vtkInteractorStyleTrackballCamera()  # camera-style interaction: dragging moves the camera
    ir.SetInteractorStyle(style)  # attach the interaction style to the interactor
    reader = vtk.vtkNIFTIImageReader()  # file reader for NIfTI volumes
    reader.SetFileName(niipath)  # set the file to read

    contourfilter = vtk.vtkContourFilter()  # vtkContourFilter extracts isosurfaces from the data
    contourfilter.SetInputConnection(reader.GetOutputPort())
    contourfilter.SetValue(0, 250)  # isosurface at scalar value 250

    smooth = vtk.vtkSmoothPolyDataFilter()  # smooth the extracted mesh
    smooth.SetInputConnection(contourfilter.GetOutputPort())
    smooth.SetNumberOfIterations(300)

    normal = vtk.vtkPolyDataNormals()  # recompute surface normals
    normal.SetInputConnection(smooth.GetOutputPort())
    normal.SetFeatureAngle(60)

    conMapper = vtk.vtkPolyDataMapper()  # mapper for the isosurface
    conMapper.SetInputConnection(normal.GetOutputPort())  # feed the pipeline output into the mapper
    conMapper.ScalarVisibilityOff()

    conActor = vtk.vtkActor()  # create the actor
    conActor.SetMapper(conMapper)  # assign the mapper to the actor
    conActor.GetProperty().SetColor(1, 0, 0)  # color the actor red
    render.AddActor(conActor)  # add the actor to the scene

    boxFilter = vtk.vtkOutlineFilter()
    boxFilter.SetInputConnection(reader.GetOutputPort())

    boxMapper = vtk.vtkPolyDataMapper()
    boxMapper.SetInputConnection(boxFilter.GetOutputPort())

    boxActor = vtk.vtkActor()
    boxActor.SetMapper(boxMapper)
    boxActor.GetProperty().SetColor(0, 1, 0)
    render.AddActor(boxActor)  # without this, the green outline box is never shown

    camera = vtk.vtkCamera()
    camera.SetViewUp(0, 0, -1)
    camera.SetPosition(0, 1, 0)
    camera.SetFocalPoint(0, 0, 0)
    camera.ComputeViewPlaneNormal()
    camera.Dolly(1.5)

    render.SetActiveCamera(camera)
    render.ResetCamera()
    ir.Initialize()
    renWin.Render()  # render once before handing control to the interactor
    ir.Start()

showNiiVtk3D('F:\\data\\diao_0.nii')

An alternative version (not recommended):

import vtk

reader = vtk.vtkNIFTIImageReader()
reader.SetFileName('./fixed_data/ct/volume-27.nii')
reader.Update()

mapper = vtk.vtkGPUVolumeRayCastMapper()
mapper.SetInputData(reader.GetOutput())

volume = vtk.vtkVolume()
volume.SetMapper(mapper)

vol_property = vtk.vtkVolumeProperty()  # renamed from `property` to avoid shadowing the Python builtin

popacity = vtk.vtkPiecewiseFunction()  # scalar value -> opacity transfer function
popacity.AddPoint(1000, 0.0)
popacity.AddPoint(4000, 0.68)
popacity.AddPoint(7000, 0.83)

color = vtk.vtkColorTransferFunction()  # scalar value -> color transfer function (HSV)
color.AddHSVPoint(1000, 0.042, 0.73, 0.55)
color.AddHSVPoint(2500, 0.042, 0.73, 0.55, 0.5, 0.92)
color.AddHSVPoint(4000, 0.088, 0.67, 0.88)
color.AddHSVPoint(5500, 0.088, 0.67, 0.88, 0.33, 0.45)
color.AddHSVPoint(7000, 0.95, 0.063, 1.0)

vol_property.SetColor(color)
vol_property.SetScalarOpacity(popacity)
vol_property.ShadeOn()
vol_property.SetInterpolationTypeToLinear()
vol_property.SetShade(0, 1)
vol_property.SetDiffuse(0.9)
vol_property.SetAmbient(0.1)
vol_property.SetSpecular(0.2)
vol_property.SetSpecularPower(10.0)
vol_property.SetComponentWeight(0, 1)
vol_property.SetDisableGradientOpacity(1)
vol_property.DisableGradientOpacityOn()
vol_property.SetScalarOpacityUnitDistance(0.891927)

volume.SetProperty(vol_property)

ren = vtk.vtkRenderer()
ren.AddActor(volume)
ren.SetBackground(0.1, 0.2, 0.4)

renWin = vtk.vtkRenderWindow()
renWin.AddRenderer(ren)

iren = vtk.vtkRenderWindowInteractor()
iren.SetRenderWindow(renWin)
renWin.SetSize(600, 600)
renWin.Render()
iren.Start()

Reading NIfTI (.nii) medical images:

## Method 1: SimpleITK
import SimpleITK as sitk

itk_img = sitk.ReadImage('./fixed_data/ct/volume-27.nii')
img = sitk.GetArrayFromImage(itk_img)
print(img.shape)  # (155, 240, 240): the number of slices along each axis, in (D, H, W) order
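
The post only shows one method; for comparison, a second common reader is nibabel (my addition, a sketch). Note that nibabel returns axes in (H, W, D) order, the reverse of SimpleITK:

## Method 2: nibabel
import nibabel as nib

nii = nib.load('./fixed_data/ct/volume-27.nii')
img = nii.get_fdata()   # float64 numpy array, axes in (H, W, D) order
print(img.shape)        # e.g. (240, 240, 155)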

MONAI framework syntax:

Official docs: Project MONAI — MONAI 1.0.1 Documentation

Recommended reading: 3D image spatial transforms with the MONAI deep learning framework (不入流儿's blog, CSDN)

MONAI (3) — Understanding every Transform in one read (part 1) (Tina姐's blog, CSDN)

MONAI (4) — Understanding every Transform in one read (part 2) (Tina姐's blog, CSDN)

load_decathlon_datalist:

Quickly loading JSON data lists with load_decathlon_datalist (MONAI) (Tina姐's blog, CSDN)
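
The linked post walks through this; a minimal sketch of the call itself (the dataset.json path below is hypothetical) might look like:

from monai.data import load_decathlon_datalist

datalist = load_decathlon_datalist(
    "./dataset/dataset.json",   # hypothetical Decathlon-style JSON file
    is_segmentation=True,
    data_list_key="training",   # the standard Decathlon key for the training split
    base_dir="./dataset",
)
print(datalist[0])  # {'image': '...', 'label': '...'}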

CacheDataset:

monai.data.CacheDataset vs monai.data.Dataset (Tina姐's blog, CSDN)
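
As a rough sketch of the difference (the file names below are placeholders): CacheDataset pre-computes and caches the deterministic transforms, so epochs run much faster than with a plain monai.data.Dataset, at the cost of RAM:

from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd

train_files = [{"image": "img_0.nii.gz", "label": "seg_0.nii.gz"}]  # placeholder list

train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
])

# cache_rate=1.0 caches every deterministic-transform result in memory
train_ds = CacheDataset(data=train_files, transform=train_transforms,
                        cache_rate=1.0, num_workers=4)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)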

sliding_window_inference:

Saving medical image segmentation results (老油条666's blog, CSDN)
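
For reference, a minimal sketch of the call (model and val_image are assumed to exist: a trained 3D network and a (B, C, D, H, W) tensor):

import torch
from monai.inferers import sliding_window_inference

with torch.no_grad():
    val_outputs = sliding_window_inference(
        inputs=val_image,
        roi_size=(96, 96, 96),   # patch size the model was trained on
        sw_batch_size=4,         # number of windows run through the model at once
        predictor=model,
        overlap=0.25,            # overlap fraction between adjacent windows
    )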

LoadImage (loads a single image):

import numpy as np
from monai.transforms import LoadImage

loader = LoadImage(dtype=np.float32, image_only=True)
data = loader("../Coronary_Segmentation_deep_learning/dataset/10423186/img.nii.gz")
print(data.shape)
## with image_only=True, data is a Tensor holding just the image

LoadImageD (dictionary version, for {"image": ..., "label": ...} samples):

from monai.transforms import LoadImaged  # LoadImaged and LoadImageD are aliases

dict_loader = LoadImaged(keys=("image", "label"), image_only=False)

# data_dict = dict_loader(list_of_dicts[0])
data_dict = dict_loader({"image": "../Coronary_Segmentation_deep_learning/dataset/10423186/img.nii.gz",
                         "label": "../Coronary_Segmentation_deep_learning/dataset/10423186/img.nii.gz"})

Official scikit-image docs:

Module: filters — skimage v0.19.2 docs

Frangi filtering in skimage:

import cv2
import numpy as np
from skimage.filters import frangi

path = 'D:/test/Python/GAN/SegAN-master/VOC/head/images/001.jpg'
img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
cv2.imshow('1', img)
cv2.waitKey(0)
cv2.destroyAllWindows()

## Note: convert to float64 first, otherwise the filtered image will not display
img = img.astype(np.float64)

img = frangi(img, sigmas=range(1, 2, 2), black_ridges=False)
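
The snippet stops after filtering; since frangi returns a float response that cv2.imshow cannot display directly, one way to view the result (my addition) is to rescale it to uint8 first:

vessels = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
cv2.imshow('frangi', vessels)
cv2.waitKey(0)
cv2.destroyAllWindows()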

Reposted from blog.csdn.net/qq_42792802/article/details/127333521