Project download address: Segmentation of unet network based on CARVANA dataset
1. Introduction
The directory structure of the project is as follows:
- data: stores the training data (5056 images) and the validation data (32 images)
- saved_val_images: stores the network's segmentation results on the validation set
CARVANA data:
The corresponding segmentation label:
2. UNET network
UNET is named for its U-shaped architecture: the left side of the network is the downsampling (encoder) path, and the right side is the upsampling (decoder) path.
For details, please refer to the previous article: UNET
The UNet here is built differently from the one in the previous article. Both are valid implementations of UNet, and either can be used.
import torch.nn as nn
import torch
import torchvision.transforms.functional as TF

# Build the UNet network
class DoubleConv(nn.Module):  # two consecutive 3x3 convolutions
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        # bias=False because each convolution is followed by BatchNorm
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.double_conv(x)
        return x


class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):  # features: channel count at each level
        super(UNet, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # down sampling part of unet
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # up sampling part of unet
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature*2, feature))

        # bottom part of unet
        self.bottleneck = DoubleConv(features[-1], features[-1]*2)

        # out layer part of unet
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []  # encoder feature maps kept for multi-scale fusion (skip connections)

        # down sampling
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        # up sampling
        for idx in range(0, len(self.ups), 2):  # self.ups alternates ConvTranspose2d and DoubleConv
            x = self.ups[idx](x)  # transposed convolution
            skip_connection = skip_connections[idx // 2]
            if x.shape != skip_connection.shape:  # allows input sizes not divisible by 16
                x = TF.resize(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)  # concatenate along the channel dimension
            x = self.ups[idx+1](concat_skip)  # DoubleConv

        x = self.final_conv(x)
        return x

# if __name__ == '__main__':
#     x = torch.rand((3, 1, 159, 159))
#     model = UNet(in_channels=1, out_channels=1)
#     out = model(x)
#     assert x.shape == out.shape
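To see where the feature*2 channel counts come from and why forward() needs the TF.resize branch, the plain-Python sketch below traces channels and spatial size through this UNet for a square input. It is a hypothetical helper (no PyTorch needed) and assumes that the padded 3x3 convolutions keep the size, MaxPool2d(kernel_size=2, stride=2) halves with floor, and ConvTranspose2d(kernel_size=2, stride=2) exactly doubles.

```python
def trace_unet(size, features=(64, 128, 256, 512)):
    """Trace (channels, spatial size) through the UNet above (hypothetical helper).

    Returns the encoder skip shapes, the bottleneck shape, and for each decoder
    step (upsampled size, skip size, channels after concatenation).
    """
    skips = []
    for f in features:
        skips.append((f, size))      # DoubleConv(in, f); padding=1 keeps the size
        size = size // 2             # MaxPool2d(kernel_size=2, stride=2) floors odd sizes
    ch = features[-1] * 2            # bottleneck DoubleConv doubles the channels
    bottleneck = (ch, size)
    decoder = []
    for (skip_ch, skip_size), f in zip(reversed(skips), reversed(features)):
        assert ch == f * 2           # input channels expected by ConvTranspose2d(f*2, f)
        up_size = size * 2           # ConvTranspose2d(kernel_size=2, stride=2) doubles
        decoder.append((up_size, skip_size, skip_ch + f))  # concat -> f*2 channels
        ch, size = f, skip_size      # after resize (if needed) + DoubleConv(f*2, f)
    return skips, bottleneck, decoder
```

With the 159x159 input from the commented-out test, every decoder step is one pixel short of its skip connection (18 vs 19, 38 vs 39, ...), so TF.resize is triggered four times. With the 160x240 size used for training, both sides are divisible by 16 and the shapes match without resizing.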
3. Dataset loading
This is similar to the dataset class in the previous article, with a few small differences.
For details, please refer to the previous article: dataset
Here is the code for the dataset:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

# Data loading
class CarvanaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir  # directory of the training images
        self.mask_dir = mask_dir  # directory of the labels (masks)
        self.transform = transform
        self.images = os.listdir(image_dir)  # all file names in the image directory

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])  # full path of this image
        mask_path = os.path.join(self.mask_dir, self.images[index].replace('.jpg', '_mask.gif'))  # the mask name differs only in its suffix
        image = np.array(Image.open(img_path).convert('RGB'))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)  # 'L' = 8-bit grayscale
        mask[mask == 255.0] = 1.0  # convert to a {0, 1} binary mask
        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations['image']
            mask = augmentations['mask']
        return image, mask
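__getitem__ assumes every image has a mask whose name is the image name with '.jpg' replaced by '_mask.gif'; a missing mask only shows up as an error partway through an epoch. A minimal sketch of a check that could be run once before training (mask_name_for and find_missing_masks are hypothetical helpers, not part of the project):

```python
import os


def mask_name_for(image_name):
    """Same rule as CarvanaDataset: replace '.jpg' with '_mask.gif'."""
    return image_name.replace('.jpg', '_mask.gif')


def find_missing_masks(image_dir, mask_dir):
    """Return the image file names whose expected mask file does not exist."""
    missing = []
    for name in sorted(os.listdir(image_dir)):
        if not os.path.isfile(os.path.join(mask_dir, mask_name_for(name))):
            missing.append(name)
    return missing
```

For example, find_missing_masks('./data/train_images', './data/train_masks') should return an empty list for a complete dataset.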
Note that the label is converted to a binary {0, 1} image here.
The raw label is a binary image: foreground pixels are 255 and background pixels are 0.
At first I wondered why the label looked like a grayscale image rather than a binary one, like this:
It turned out to be a display issue: after zooming in, the label is indeed a binary image.
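Zooming in is a visual check; the pixel values can also be checked directly. A minimal sketch, assuming the mask has already been loaded as a uint8 array, e.g. with np.array(Image.open(mask_path).convert("L")) as in the dataset (binarize_mask is a hypothetical helper that fails loudly on unexpected values instead of silently leaving them in the label):

```python
import numpy as np


def binarize_mask(mask_gray):
    """Map a grayscale mask with values {0, 255} to float32 {0.0, 1.0}."""
    values = np.unique(mask_gray)
    if not set(values.tolist()) <= {0, 255}:
        raise ValueError(f"mask is not binary, found values: {values}")
    mask = mask_gray.astype(np.float32)
    mask[mask == 255.0] = 1.0  # same conversion as CarvanaDataset.__getitem__
    return mask
```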
4. utils module
To keep the main program simple, the reusable parts are placed in the utils module, which mainly implements the following three functions:
- get_loaders: loads the data
- check_accuracy: evaluates the accuracy of the model
- save_predictions_as_imgs: saves the model's segmentation results on the validation set as images
4.1 get_loaders function
Data loading is straightforward and no different from before; it is simply wrapped in a function here.
Parameters of get_loaders:
- train_dir: directory of the training images
- train_mask_dir: directory of the training masks
- val_dir: directory of the validation images
- val_mask_dir: directory of the validation masks
- batch_size: the batch size
- train_transform / val_transform: preprocessing applied to the training and validation data
- num_workers: number of worker processes used to load data. On Windows, either set it to 0 or put the entry code under if __name__ == '__main__': so that num_workers != 0 works
get_loaders returns train_loader and val_loader, which yield (image, mask) batches from the training and validation sets respectively.
4.2 check_accuracy function
check_accuracy evaluates the model. Its arguments are loader (the validation loader, which yields images and labels), model (the network to evaluate), and device (the device the network runs on).
Because the mask has no channel dimension ([N, H, W]), the label is unsqueezed to [N, 1, H, W] to match the network output.
The network output is passed through a sigmoid; pixels with a value greater than 0.5 become foreground (1) and the rest become background (0).
The Dice score is computed as follows (the 1e-8 term avoids division by zero):
dice_score += ( 2*(pred * y).sum() ) / ((pred + y).sum() + 1e-8 )
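As a sanity check on the two metrics printed by check_accuracy, here is a NumPy sketch that applies the same sigmoid and 0.5 threshold and the same accuracy and Dice formulas to a single batch (binary_metrics is a hypothetical name, not part of the project):

```python
import numpy as np


def binary_metrics(logits, target, eps=1e-8):
    """Pixel accuracy and Dice score for one batch, mirroring check_accuracy.

    logits: raw network outputs, shape [N, 1, H, W]
    target: {0, 1} labels with the same shape
    """
    prob = 1.0 / (1.0 + np.exp(-logits))        # sigmoid
    pred = (prob > 0.5).astype(np.float32)      # mathematically equivalent to logits > 0
    accuracy = (pred == target).sum() / pred.size
    dice = 2.0 * (pred * target).sum() / ((pred + target).sum() + eps)
    return float(accuracy), float(dice)
```

Since sigmoid(z) > 0.5 exactly when z > 0 (up to floating-point rounding for tiny z), the logits could also be thresholded at 0 directly. Note also that check_accuracy computes Dice over each whole batch and then averages over batches, so a smaller last batch carries the same weight as a full one.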
4.3 save_predictions_as_imgs function
The function to save the image is shown in the figure:
- Converting a tensor to a NumPy array just to save it as an image is cumbersome. torchvision provides save_image() in torchvision.utils, which saves a tensor directly as an image; if the tensor is on CUDA, it is moved to the CPU before saving.
- save_image() is the usual way to save images from a deep learning model, but a single-channel input is saved as a 3-channel RGB image. If the network output is a single-channel grayscale image, that channel is copied into all three channels, giving a "pseudo-grayscale" image: it looks identical, but the file stores three identical channels and takes more space than a true single-channel image.
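If a true single-channel file is wanted, one option is to bypass save_image and write the mask with Pillow. A minimal sketch, assuming a 2-D {0, 1} mask already moved to the CPU as a NumPy array (save_mask_as_gray_png is a hypothetical helper, not part of utils):

```python
import numpy as np
from PIL import Image


def save_mask_as_gray_png(mask, path):
    """Save a 2-D {0, 1} mask as a single-channel (mode 'L') PNG."""
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"expected a 2-D mask, got shape {mask.shape}")
    img = (mask > 0.5).astype(np.uint8) * 255  # {0, 1} -> {0, 255}
    Image.fromarray(img).save(path)  # a 2-D uint8 array becomes a mode 'L' image
```

For a prediction tensor of shape [N, 1, H, W], the first mask could be saved with save_mask_as_gray_png(pred[0, 0].cpu().numpy(), 'pred_0.png'). Unlike save_image, this writes one file per mask instead of a grid of the whole batch.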
4.4 Complete code
The complete utils code is as follows:
import os
import torch
import torchvision
from dataset import CarvanaDataset
from torch.utils.data import DataLoader

# Build the training and validation data loaders
def get_loaders(train_dir, train_mask_dir, val_dir, val_mask_dir, batch_size, train_transform, val_transform, num_workers):
    # training set
    train_set = CarvanaDataset(image_dir=train_dir, mask_dir=train_mask_dir, transform=train_transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=num_workers, shuffle=True)
    # validation set
    val_set = CarvanaDataset(image_dir=val_dir, mask_dir=val_mask_dir, transform=val_transform)
    val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=num_workers, shuffle=False)
    return train_loader, val_loader

# Evaluate pixel accuracy and Dice score
def check_accuracy(loader, model, device):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()  # evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)  # add the channel dimension to the label
            pred = torch.sigmoid(model(x))
            pred = (pred > 0.5).float()  # threshold into a binary image
            num_correct += (pred == y).sum()  # pixels where prediction and label agree
            num_pixels += torch.numel(pred)  # total number of pixels
            dice_score += (2 * (pred * y).sum()) / ((pred + y).sum() + 1e-8)
    # correctly predicted pixels / total pixels
    print(
        f'Got {num_correct}/{num_pixels} with accuracy {num_correct/num_pixels*100:.2f}%'
    )
    # Dice score, averaged over batches
    print(f'Dice score : {dice_score / len(loader)}')
    model.train()

# Save predicted masks and labels as images
def save_predictions_as_imgs(loader, model, device, folder='./saved_val_images/'):
    print('------>Loading predictions')
    os.makedirs(folder, exist_ok=True)  # make sure the output folder exists
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            pred = torch.sigmoid(model(x))
            pred = (pred > 0.5).float()
        torchvision.utils.save_image(pred, os.path.join(folder, f'pred_{idx}.png'))  # save predicted masks
        torchvision.utils.save_image(y.unsqueeze(1), os.path.join(folder, f'label_{idx}.png'))  # save label masks
    model.train()
5. train function
The train script contains the main function that trains the network.
OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see http://www.intel.com/software/products/support/.
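Running the training script here may report the error above. A simple workaround, which the message itself describes as unsafe, is to add the following at the very top of the script, before importing torch or any other library that loads OpenMP: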
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
5.1 Imported libraries
import torch
import albumentations as A  # image augmentation library
from albumentations.pytorch import ToTensorV2  # only converts [h, w, c] -> [c, h, w]; does not scale to [0, 1]
from tqdm import tqdm  # progress bar
import torch.nn as nn
from unet import UNet
import torch.optim as optim
# custom modules
from utils import (
    get_loaders,  # load the data
    check_accuracy,  # check validation accuracy
    save_predictions_as_imgs,  # save predicted images
)
Some of these libraries differ from those used in previous articles; each one is annotated with a comment.
5.2 Setting hyperparameters
Note LOAD_MODEL: it acts as a switch for whether to load pre-trained weights.
If the network has been trained before and a saved weight file exists, setting LOAD_MODEL to True loads that file; the learning rate can then be adjusted as needed before continuing training.
5.3 train_fn: training for one epoch
The code is shown below:
5.4 main function
Define the preprocessing of the training data:
Define the preprocessing of validation data:
Create a model:
Get the training and validation data from the get_loaders function:
Whether to load the pre-trained model:
Train model + save parameters + display prediction results:
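The A.Normalize step in both transforms uses mean=0, std=1 and max_pixel_value=255, so the formula in its code comment reduces to dividing by 255. That is why the inputs end up in [0, 1] even though ToTensorV2 does not rescale. A NumPy sketch of that per-channel formula (a re-implementation for illustration, not the albumentations source; normalize_like_albumentations is a hypothetical name):

```python
import numpy as np


def normalize_like_albumentations(img, mean, std, max_pixel_value=255.0):
    """(img - mean * max_pixel_value) / (std * max_pixel_value), per channel.

    img: uint8 array of shape [H, W, C]; mean and std: length-C sequences.
    """
    mean = np.asarray(mean, dtype=np.float32) * max_pixel_value
    std = np.asarray(std, dtype=np.float32) * max_pixel_value
    return (img.astype(np.float32) - mean) / std  # broadcasts over the channel axis
```

With mean=[0.0, 0.0, 0.0] and std=[1.0, 1.0, 1.0] a pixel of (0, 51, 255) becomes (0.0, 0.2, 1.0).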
5.5 Complete code
as follows:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # unsafe workaround for the OpenMP error above; must come before import torch
import torch
import albumentations as A  # image augmentation library
from albumentations.pytorch import ToTensorV2  # only converts [h, w, c] -> [c, h, w]; does not scale to [0, 1]
from tqdm import tqdm  # progress bar
import torch.nn as nn
from unet import UNet
import torch.optim as optim
# custom modules
from utils import (
    get_loaders,  # load the data
    check_accuracy,  # check validation accuracy
    save_predictions_as_imgs,  # save predicted images
)

# hyperparameters
LEARNING_RATE = 1e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 16
NUM_EPOCHS = 2  # number of epochs
NUM_WORKERS = 5
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
LOAD_MODEL = False
TRAIN_IMG_DIR = './data/train_images'
TRAIN_MASK_DIR = './data/train_masks'
VAL_IMG_DIR = './data/val_images'
VAL_MASK_DIR = './data/val_masks'

# train for one epoch
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)
    for batch_idx, (img, label) in enumerate(loop):
        img = img.to(device=DEVICE)
        label = label.float().unsqueeze(1).to(DEVICE)  # add the channel dimension

        # forward
        with torch.cuda.amp.autocast():  # mixed precision: eligible ops run in float16 to speed up training
            predictions = model(img)  # network output (logits)
            loss = loss_fn(predictions, label)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())

def main():
    # training data preprocessing
    train_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=0.5),  # random rotation in (-limit, limit), applied with 50% probability
            A.HorizontalFlip(p=0.5),  # 50% chance of a horizontal flip (about the vertical axis)
            A.VerticalFlip(p=0.1),  # 10% chance of a vertical flip (about the horizontal axis)
            A.Normalize(  # img = (img - mean * max_pixel_value) / (std * max_pixel_value)
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0
            ),
            ToTensorV2(),  # [h, w, c] -> [c, h, w]
        ]
    )
    # validation data preprocessing
    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0
            ),
            ToTensorV2(),
        ]
    )

    # instantiate the UNet model + loss + optimizer
    model = UNet(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()  # sigmoid + binary cross-entropy in one op
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # get the data loaders
    # train_loader: train_images, train_masks
    # val_loader: val_images, val_masks
    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transforms,
        val_transforms,
        NUM_WORKERS,
    )

    # load pre-trained weights
    if LOAD_MODEL:
        print('Pretrained:')
        model.load_state_dict(torch.load('unet.pth', map_location=DEVICE))  # map_location allows loading GPU weights on CPU
        check_accuracy(val_loader, model, device=DEVICE)
        print('------>Loading pretrained model successfully!!')

    scaler = torch.cuda.amp.GradScaler()  # gradient scaling for mixed-precision training
    for epoch in range(NUM_EPOCHS):
        print('Epoch:', epoch + 1)
        train_fn(train_loader, model, optimizer, loss_fn, scaler)  # train one epoch

        # check accuracy
        check_accuracy(val_loader, model, device=DEVICE)

        # save model
        print('------>Saving checkpoint')
        torch.save(model.state_dict(), 'unet.pth')

        # print some examples to a folder
        save_predictions_as_imgs(val_loader, model, folder='saved_val_images/', device=DEVICE)

if __name__ == '__main__':  # required so that num_workers != 0 works (e.g. on Windows)
    main()
    print(' training over!!!! ')
6. Results
6.1 Network Training
Results after training the network for two epochs:
The 316 in the progress bar is the number of batches per epoch: 5056 samples / batch size 16 = 316.
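The same arithmetic in general form: with the DataLoader default drop_last=False, a final partial batch is still yielded, so the count is a ceiling division. A small sketch (batches_per_epoch is a hypothetical helper):

```python
def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader with a plain sampler yields per epoch."""
    if drop_last:
        return num_samples // batch_size
    return (num_samples + batch_size - 1) // batch_size  # integer ceiling division
```

batches_per_epoch(5056, 16) gives 316 for the training set, and batches_per_epoch(32, 16) gives 2 for the validation set, so save_predictions_as_imgs writes pred_0.png, pred_1.png, label_0.png and label_1.png each epoch.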
6.2 Training from pre-trained weights
LOAD_MODEL = True
6.3 Result display
Network predictions:
Real label:
Network predictions:
Real label: